diff --git a/README.md b/README.md index 65925810..ec0c42be 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,10 @@ Dev tunnels allows developers to securely expose local web services to the Inter |---|---|---|---|---|---| | Management API | ✅ | ✅ | ✅ | ✅ | ✅ | | Tunnel Client Connections | ✅ | ✅ | ✅ | ✅ | ✅ | -| Tunnel Host Connections | ✅ | ✅ | ❌ | ❌ | ✅ | -| Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | +| Tunnel Host Connections | ✅ | ✅ | ❌ | ✅ | ✅ | +| Reconnection | ✅ | ✅ | ❌ | ✅ | ❌ | | SSH-level Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | -| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ❌ | ❌ | +| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ✅ | ❌ | | Ssh Keep-alive | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ - Supported diff --git a/go/tunnels/connection_status.go b/go/tunnels/connection_status.go new file mode 100644 index 00000000..cccbd777 --- /dev/null +++ b/go/tunnels/connection_status.go @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +// ConnectionStatus represents the connection state of a tunnel host. +type ConnectionStatus int + +const ( + // ConnectionStatusNone indicates no connection has been made. + ConnectionStatusNone ConnectionStatus = iota + + // ConnectionStatusConnecting indicates a connection is in progress. + ConnectionStatusConnecting + + // ConnectionStatusConnected indicates the host is connected. + ConnectionStatusConnected + + // ConnectionStatusDisconnected indicates the host has disconnected. + ConnectionStatusDisconnected +) + +// String returns the string representation of the connection status. +func (s ConnectionStatus) String() string { + switch s { + case ConnectionStatusNone: + return "None" + case ConnectionStatusConnecting: + return "Connecting" + case ConnectionStatusConnected: + return "Connected" + case ConnectionStatusDisconnected: + return "Disconnected" + default: + return "Unknown" + } +} diff --git a/go/tunnels/examples/getting_started.md b/go/tunnels/examples/getting_started.md index 5ae37db7..23c23bdd 100644 --- a/go/tunnels/examples/getting_started.md +++ b/go/tunnels/examples/getting_started.md @@ -1,9 +1,55 @@ # Getting Started -To use the example you must do the following setup first: +## Client Example + +To use the client example you must do the following setup first: 1. Create a tunnel on the CLI or another SDK and put the tunnelId and clusterId in the constants section of example.go 2. Create ports on the tunnel that you want to be hosted 3. Get a tunnels access token and paste it in the return value of getAccessToken() in example.go or set it as the TUNNELS_TOKEN environment variable 4. Start hosting the tunnel either on the CLI or on a different SDK -5. Run example.go with the command `go run example.go` \ No newline at end of file +5. Run example.go with the command `go run example.go` + +## Host Example + +To use the host example: + +1. Create a tunnel on the CLI or management API +2. Get a host access token for the tunnel and set it as the TUNNELS_TOKEN environment variable +3. Set the `hostTunnelID` and `hostClusterID` constants in host/host_example.go +4. Set the `localPort` constant to the local TCP port you want to forward (default: 8080) +5. Start a local service on that port (e.g., `python -m http.server 8080`) +6. Run the host: `cd host && TUNNELS_TOKEN= go run host_example.go` + +The host will: +- Connect to the relay and register an endpoint +- Forward the specified local port to remote clients +- Automatically reconnect if the relay connection drops +- Shut down gracefully on Ctrl+C (unregisters the endpoint) + +### Host API Overview + +```go +// Create a host +host, err := tunnels.NewHost(logger, manager) + +// Optional: enable reconnection and status callbacks +host.EnableReconnect = true +host.ConnectionStatusChanged = func(prev, curr tunnels.ConnectionStatus) { ... } + +// Connect to the relay +host.Connect(ctx, tunnel) + +// Add/remove forwarded ports dynamically +host.AddPort(ctx, &tunnels.TunnelPort{PortNumber: 8080}) +host.RemovePort(ctx, 8080) + +// Sync ports with the management service +host.RefreshPorts(ctx) + +// Block until disconnected (reconnects automatically if enabled) +host.Wait() + +// Graceful shutdown +host.Close() +``` diff --git a/go/tunnels/examples/host/host_example.go b/go/tunnels/examples/host/host_example.go new file mode 100644 index 00000000..17618cd2 --- /dev/null +++ b/go/tunnels/examples/host/host_example.go @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// This example demonstrates how to use the Go SDK to host a tunnel, +// forwarding a local TCP port to remote clients through the relay. +// +// Prerequisites: +// - A tunnel created via the CLI or management API +// - A host access token for the tunnel +// +// Usage: +// TUNNELS_TOKEN= go run host_example.go + +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "os/signal" + "syscall" + + tunnels "github.com/microsoft/dev-tunnels/go/tunnels" +) + +// Set the tunnel ID and cluster ID for the tunnel you want to host. +const ( + hostTunnelID = "" + hostClusterID = "usw2" + + // The local port to forward through the tunnel. + localPort = 8080 +) + +var ( + hostURI = tunnels.ServiceProperties.ServiceURI + hostUserAgent = []tunnels.UserAgent{{Name: "Tunnels-Go-SDK-Host-Example", Version: "0.0.1"}} +) + +// getHostAccessToken returns the host access token from the TUNNELS_TOKEN +// environment variable. +func getHostAccessToken() string { + if token := os.Getenv("TUNNELS_TOKEN"); token != "" { + return token + } + return "" +} + +func main() { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + logger := log.New(os.Stdout, "[host] ", log.LstdFlags) + + parsedURL, err := url.Parse(hostURI) + if err != nil { + logger.Fatalf("Failed to parse service URI: %v", err) + } + + // Create management client. + mgr, err := tunnels.NewManager(hostUserAgent, getHostAccessToken, parsedURL, nil, "2023-09-27-preview") + if err != nil { + logger.Fatalf("Failed to create manager: %v", err) + } + + // Fetch the tunnel with a host access token. + tunnel := &tunnels.Tunnel{ + TunnelID: hostTunnelID, + ClusterID: hostClusterID, + } + options := &tunnels.TunnelRequestOptions{ + IncludePorts: true, + TokenScopes: []tunnels.TunnelAccessScope{"host"}, + } + + tunnel, err = mgr.GetTunnel(ctx, tunnel, options) + if err != nil { + logger.Fatalf("Failed to get tunnel: %v", err) + } + logger.Printf("Got tunnel: %s", tunnel.TunnelID) + + // Create the host. + host, err := tunnels.NewHost(logger, mgr) + if err != nil { + logger.Fatalf("Failed to create host: %v", err) + } + + // Optional: enable automatic reconnection on relay disconnect. + host.EnableReconnect = true + + // Optional: log connection status changes. + host.ConnectionStatusChanged = func(prev, curr tunnels.ConnectionStatus) { + logger.Printf("Connection status: %v -> %v", prev, curr) + } + + // Connect to the relay. + if err := host.Connect(ctx, tunnel); err != nil { + logger.Fatalf("Failed to connect: %v", err) + } + logger.Printf("Connected to relay") + + // Add a port to forward. This registers the port with the management API + // and notifies any connected clients via SSH tcpip-forward. + port := &tunnels.TunnelPort{PortNumber: localPort} + if err := host.AddPort(ctx, port); err != nil { + logger.Fatalf("Failed to add port: %v", err) + } + logger.Printf("Forwarding local port %d", localPort) + + // Wait for the relay connection (blocks until disconnect or signal). + // With EnableReconnect=true, this will automatically reconnect on drops. + go func() { + if err := host.Wait(); err != nil { + logger.Printf("Relay connection ended: %v", err) + } + }() + + // Wait for interrupt signal. + <-ctx.Done() + logger.Printf("Shutting down...") + + // Close gracefully: closes the SSH session and unregisters the endpoint. + if err := host.Close(); err != nil { + logger.Printf("Close error: %v", err) + } + logger.Printf("Host shut down") + + fmt.Println("Done.") +} diff --git a/go/tunnels/examples/host_e2e_test.go b/go/tunnels/examples/host_e2e_test.go new file mode 100644 index 00000000..d23ac114 --- /dev/null +++ b/go/tunnels/examples/host_e2e_test.go @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// E2E tests for the Host API against the live dev-tunnels relay service. +// These tests are skipped in short mode (go test -short). +// +// To run manually: +// +// TUNNEL_AUTH_TOKEN= go test -v -run TestHostE2ELiveRelay +// TUNNEL_AUTH_TOKEN= go test -v -run TestHostAndClientE2ELiveRelay +// +// The auth token must be a valid Azure AD or GitHub token for the dev-tunnels service. +// You can obtain one via the VS Code dev tunnels extension or the `devtunnel` CLI: +// +// devtunnel user login +// devtunnel token +package main + +import ( + "context" + "io" + "log" + "net" + "net/url" + "os" + "strings" + "testing" + "time" + + tunnels "github.com/microsoft/dev-tunnels/go/tunnels" +) + +func getTestAuthToken() string { + if token := os.Getenv("TUNNEL_AUTH_TOKEN"); token != "" { + return token + } + return "" +} + +// formatAuthToken ensures the token has the correct scheme prefix. +// GitHub tokens (ghu_) need "github " prefix, AAD tokens need "Bearer " prefix. +func formatAuthToken(token string) string { + if strings.HasPrefix(token, "github ") || strings.HasPrefix(token, "Bearer ") || strings.HasPrefix(token, "Tunnel ") { + return token + } + if strings.HasPrefix(token, "ghu_") || strings.HasPrefix(token, "gho_") { + return "github " + token + } + return "Bearer " + token +} + +func newTestManagerForE2E(t *testing.T) *tunnels.Manager { + t.Helper() + token := getTestAuthToken() + if token == "" { + t.Skip("TUNNEL_AUTH_TOKEN not set") + } + + serviceURL, err := url.Parse(tunnels.ServiceProperties.ServiceURI) + if err != nil { + t.Fatalf("failed to parse service URL: %v", err) + } + + userAgents := []tunnels.UserAgent{{Name: "Tunnels-Go-SDK-E2E-Test", Version: "0.0.1"}} + formattedToken := formatAuthToken(token) + mgr, err := tunnels.NewManager(userAgents, func() string { return formattedToken }, serviceURL, nil, "2023-09-27-preview") + if err != nil { + t.Fatalf("failed to create manager: %v", err) + } + return mgr +} + +// TestHostE2ELiveRelay validates host lifecycle against the real relay service. +// Skipped in short mode. Requires TUNNEL_AUTH_TOKEN env var. +func TestHostE2ELiveRelay(t *testing.T) { + if testing.Short() { + t.Skip("skipping live relay test in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + mgr := newTestManagerForE2E(t) + logger := log.New(os.Stdout, "e2e-host: ", log.LstdFlags) + + // 1. Create a tunnel. + tunnel, err := mgr.CreateTunnel(ctx, &tunnels.Tunnel{}, nil) + if err != nil { + t.Fatalf("CreateTunnel failed: %v", err) + } + t.Logf("Created tunnel: %s (cluster: %s)", tunnel.TunnelID, tunnel.ClusterID) + + // Ensure cleanup. + defer func() { + cleanupCtx := context.Background() + if err := mgr.DeleteTunnel(cleanupCtx, tunnel, nil); err != nil { + t.Logf("Warning: DeleteTunnel failed: %v", err) + } + }() + + // Request a host access token. + tokenOptions := &tunnels.TunnelRequestOptions{ + TokenScopes: []tunnels.TunnelAccessScope{tunnels.TunnelAccessScopeHost}, + } + tunnel, err = mgr.GetTunnel(ctx, tunnel, tokenOptions) + if err != nil { + t.Fatalf("GetTunnel (with host token) failed: %v", err) + } + + // 2. Create a Host and connect. + host, err := tunnels.NewHost(logger, mgr) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + t.Log("Host connected to relay") + + // 3. Start a local echo server and add the port. + echoListener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + defer echoListener.Close() + + echoPort := uint16(echoListener.Addr().(*net.TCPAddr).Port) + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + port := &tunnels.TunnelPort{PortNumber: echoPort} + if err := host.AddPort(ctx, port); err != nil { + t.Fatalf("Host.AddPort failed: %v", err) + } + t.Logf("Added port %d", echoPort) + + // 4. Verify the endpoint was registered. + verifyTunnel, err := mgr.GetTunnel(ctx, tunnel, nil) + if err != nil { + t.Fatalf("GetTunnel (verify) failed: %v", err) + } + + if len(verifyTunnel.Endpoints) == 0 { + t.Fatal("expected at least one endpoint after host connect") + } + t.Logf("Verified %d endpoint(s) registered", len(verifyTunnel.Endpoints)) + + // 5. Close the host and verify cleanup. + if err := host.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + t.Log("Host closed successfully") +} + +// TestHostAndClientE2ELiveRelay validates the full tunnel lifecycle with both +// a host and a client connecting through the live relay service. +// Skipped in short mode. Requires TUNNEL_AUTH_TOKEN env var. +func TestHostAndClientE2ELiveRelay(t *testing.T) { + if testing.Short() { + t.Skip("skipping live relay test in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + mgr := newTestManagerForE2E(t) + logger := log.New(os.Stdout, "e2e-full: ", log.LstdFlags) + + // 1. Create a tunnel with host and connect access tokens. + tunnel, err := mgr.CreateTunnel(ctx, &tunnels.Tunnel{}, nil) + if err != nil { + t.Fatalf("CreateTunnel failed: %v", err) + } + t.Logf("Created tunnel: %s (cluster: %s)", tunnel.TunnelID, tunnel.ClusterID) + + defer func() { + cleanupCtx := context.Background() + if err := mgr.DeleteTunnel(cleanupCtx, tunnel, nil); err != nil { + t.Logf("Warning: DeleteTunnel failed: %v", err) + } + }() + + // Request both host and connect tokens. + tokenOptions := &tunnels.TunnelRequestOptions{ + TokenScopes: []tunnels.TunnelAccessScope{ + tunnels.TunnelAccessScopeHost, + tunnels.TunnelAccessScopeConnect, + }, + } + tunnel, err = mgr.GetTunnel(ctx, tunnel, tokenOptions) + if err != nil { + t.Fatalf("GetTunnel (with tokens) failed: %v", err) + } + + // 2. Start a local echo server. + echoListener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + defer echoListener.Close() + + echoPort := uint16(echoListener.Addr().(*net.TCPAddr).Port) + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + // 3. Host connects and adds the echo port. + host, err := tunnels.NewHost(logger, mgr) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + t.Log("Host connected") + + port := &tunnels.TunnelPort{PortNumber: echoPort} + if err := host.AddPort(ctx, port); err != nil { + t.Fatalf("Host.AddPort failed: %v", err) + } + t.Logf("Host added port %d", echoPort) + + // 4. Client connects to the same tunnel. + // Re-fetch the tunnel to get updated endpoints and connect token. + connectOptions := &tunnels.TunnelRequestOptions{ + TokenScopes: []tunnels.TunnelAccessScope{tunnels.TunnelAccessScopeConnect}, + IncludePorts: true, + } + clientTunnel, err := mgr.GetTunnel(ctx, tunnel, connectOptions) + if err != nil { + t.Fatalf("GetTunnel (for client) failed: %v", err) + } + + client, err := tunnels.NewClient(logger, clientTunnel, true) + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + if err := client.Connect(ctx, ""); err != nil { + t.Fatalf("Client.Connect failed: %v", err) + } + t.Log("Client connected") + + // 5. Client waits for the forwarded port. + if err := client.WaitForForwardedPort(ctx, echoPort); err != nil { + t.Fatalf("WaitForForwardedPort failed: %v", err) + } + t.Logf("Client received forwarded port %d", echoPort) + + // 6. Client opens a connection to the forwarded port. + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + defer listener.Close() + + if err := client.ConnectListenerToForwardedPort(ctx, listener.(*net.TCPListener), echoPort); err != nil { + t.Fatalf("ConnectListenerToForwardedPort failed: %v", err) + } + + // Connect to the local listener to trigger port forwarding. + localConn, err := net.DialTimeout("tcp", listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("failed to connect to local listener: %v", err) + } + defer localConn.Close() + + // 7. Send test data and verify echo. + testData := []byte("hello e2e tunnel") + _, err = localConn.Write(testData) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + buf := make([]byte, len(testData)) + localConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + _, err = io.ReadFull(localConn, buf) + if err != nil { + t.Fatalf("failed to read echo: %v", err) + } + + if string(buf) != string(testData) { + t.Fatalf("data mismatch: sent %q, received %q", testData, buf) + } + + t.Log("E2E tunnel data integrity verified") +} diff --git a/go/tunnels/host.go b/go/tunnels/host.go new file mode 100644 index 00000000..57bcaa97 --- /dev/null +++ b/go/tunnels/host.go @@ -0,0 +1,496 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "sync" + + tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh" + "golang.org/x/crypto/ssh" +) + +const ( + hostWebSocketSubProtocol = "tunnel-relay-host" + hostWebSocketSubProtocolV2 = "tunnel-relay-host-v2-dev" +) + +var ( + // ErrNoManager is returned when no manager is provided. + ErrNoManager = errors.New("manager cannot be nil") + + // ErrNoHostRelayURI is returned when the endpoint has no host relay URI. + ErrNoHostRelayURI = errors.New("endpoint host relay URI is empty") + + // ErrAlreadyConnected is returned when the host is already connected. + ErrAlreadyConnected = errors.New("host is already connected") + + // ErrPortAlreadyAdded is returned when the port is already forwarded. + ErrPortAlreadyAdded = errors.New("port is already added") + + // ErrNotConnected is returned when the host is not connected. + ErrNotConnected = errors.New("host is not connected") + + // ErrTooManyConnections is returned when the relay rejected the host + // because another host is already connected to this tunnel. + ErrTooManyConnections = errors.New("too many connections to tunnel") +) + +// Host is a host for a tunnel. It is used to host a tunnel and forward +// local TCP ports to remote clients through the relay. +// +// Locking strategy: mu guards ssh, tunnel, connectionStatus, disconnectReason, +// ctx, and cancel. All locking uses the snapshot-under-lock pattern: acquire +// lock, copy values to locals, release, then operate on the locals. +type Host struct { + logger *log.Logger + manager *Manager + + // mu guards ssh, tunnel, connectionStatus, disconnectReason, ctx, and cancel. + mu sync.Mutex + tunnel *Tunnel + ssh *tunnelssh.HostSSHSession + connectionStatus ConnectionStatus + disconnectReason uint32 + + hostID string + endpointID string + relayURI string + + hostKey ssh.Signer + + // EnableReconnect enables automatic reconnection when the relay + // connection drops. Default is false for backward compatibility. + EnableReconnect bool + + // ConnectionStatusChanged is called when the connection status changes. + // Both the previous and current status are provided. + ConnectionStatusChanged func(prev, curr ConnectionStatus) + + // RefreshTunnelAccessTokenFunc is called to obtain a fresh access token + // when the current one expires (HTTP 401). If nil, the host falls back + // to re-fetching the tunnel from the management service. + RefreshTunnelAccessTokenFunc func(ctx context.Context) (string, error) + + // ctx and cancel are created in Connect and cancelled in Close + // to stop reconnection loops. + ctx context.Context + cancel context.CancelFunc +} + +// NewHost creates a new Host instance. +func NewHost(logger *log.Logger, manager *Manager) (*Host, error) { + if manager == nil { + return nil, ErrNoManager + } + + if logger == nil { + logger = log.New(io.Discard, "", 0) + } + + hostID, err := generateUUID() + if err != nil { + return nil, fmt.Errorf("error generating host ID: %w", err) + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("error generating host key: %w", err) + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + return nil, fmt.Errorf("error creating SSH signer: %w", err) + } + + h := &Host{ + logger: logger, + manager: manager, + hostID: hostID, + endpointID: fmt.Sprintf("%s-relay", hostID), + hostKey: signer, + } + return h, nil +} + +// HostPublicKeyBase64 returns the base64-encoded public key of the host. +func (h *Host) HostPublicKeyBase64() string { + return base64.StdEncoding.EncodeToString(h.hostKey.PublicKey().Marshal()) +} + +// ConnectionStatus returns the current connection status. +func (h *Host) ConnectionStatus() ConnectionStatus { + h.mu.Lock() + defer h.mu.Unlock() + return h.connectionStatus +} + +// setConnectionStatus updates the connection status and invokes the callback. +func (h *Host) setConnectionStatus(status ConnectionStatus) { + h.mu.Lock() + prev := h.connectionStatus + h.connectionStatus = status + cb := h.ConnectionStatusChanged + h.mu.Unlock() + + if cb != nil && prev != status { + cb(prev, status) + } +} + +// Connect connects the host to a tunnel relay. +func (h *Host) Connect(ctx context.Context, tunnel *Tunnel) error { + if ctx == nil { + ctx = context.Background() + } + + h.mu.Lock() + if h.ssh != nil { + h.mu.Unlock() + return ErrAlreadyConnected + } + if h.disconnectReason == tunnelssh.SshDisconnectReasonTooManyConnections { + h.mu.Unlock() + return ErrTooManyConnections + } + h.tunnel = tunnel + h.mu.Unlock() + + connCtx, connCancel := context.WithCancel(ctx) + h.mu.Lock() + h.ctx = connCtx + h.cancel = connCancel + h.mu.Unlock() + + return h.connectOnce(ctx, tunnel) +} + +// connectOnce performs a single connection attempt: endpoint registration, +// WebSocket connect, and SSH session setup. +func (h *Host) connectOnce(ctx context.Context, tunnel *Tunnel) error { + h.setConnectionStatus(ConnectionStatusConnecting) + + // Check if any port uses the "ssh" protocol for the gateway key query param. + var opts *TunnelRequestOptions + for _, p := range tunnel.Ports { + if p.Protocol == "ssh" { + opts = &TunnelRequestOptions{ + AdditionalQueryParameters: map[string]string{ + "includeSshGatewayPublicKey": "true", + }, + } + break + } + } + + // Register the endpoint with the management API. + endpoint := &TunnelEndpoint{ + ID: h.endpointID, + HostID: h.hostID, + ConnectionMode: TunnelConnectionModeTunnelRelay, + HostPublicKeys: []string{h.HostPublicKeyBase64()}, + } + + endpointResult, err := h.manager.UpdateTunnelEndpoint(ctx, tunnel, endpoint, nil, opts) + if err != nil { + h.setConnectionStatus(ConnectionStatusDisconnected) + return fmt.Errorf("error updating tunnel endpoint: %w", err) + } + + if endpointResult.HostRelayURI == "" { + h.setConnectionStatus(ConnectionStatusDisconnected) + return ErrNoHostRelayURI + } + h.relayURI = endpointResult.HostRelayURI + + // Extract host access token, guarding against nil map. + var accessToken string + if tunnel.AccessTokens != nil { + accessToken = tunnel.AccessTokens[TunnelAccessScopeHost] + } + + h.logger.Printf("Connecting to host tunnel relay %s", h.relayURI) + protocols := []string{hostWebSocketSubProtocolV2, hostWebSocketSubProtocol} + + var headers http.Header + if accessToken != "" { + headers = make(http.Header) + if !strings.HasPrefix(accessToken, "Tunnel ") { + accessToken = fmt.Sprintf("Tunnel %s", accessToken) + } + headers.Add("Authorization", accessToken) + } + + sock := newSocket(h.relayURI, protocols, headers, nil) + if err := sock.connect(ctx); err != nil { + h.setConnectionStatus(ConnectionStatusDisconnected) + return fmt.Errorf("error connecting to host relay: %w", err) + } + + negotiatedProtocol := sock.Subprotocol() + h.logger.Printf("Negotiated subprotocol: %s", negotiatedProtocol) + + // In V1, the relay does not handle tcpip-forward; pass empty token. + sshAccessToken := accessToken + if negotiatedProtocol == hostWebSocketSubProtocol { + sshAccessToken = "" + } + + sshSession := tunnelssh.NewHostSSHSession(sock, h.hostKey, h.logger, sshAccessToken, negotiatedProtocol) + if err := sshSession.Connect(ctx); err != nil { + sock.Close() + h.setConnectionStatus(ConnectionStatusDisconnected) + return fmt.Errorf("error establishing SSH session: %w", err) + } + + h.mu.Lock() + h.ssh = sshSession + h.mu.Unlock() + + h.setConnectionStatus(ConnectionStatusConnected) + return nil +} + +// Close gracefully shuts down the host connection. +// It closes the SSH session and unregisters the endpoint. +// Close is idempotent — calling it twice does not panic or error. +// Returns ErrNotConnected if the host was never connected. +func (h *Host) Close() error { + h.mu.Lock() + sshSession := h.ssh + tunnel := h.tunnel + cancel := h.cancel + h.ssh = nil + h.mu.Unlock() + + // Cancel any reconnection loop. + if cancel != nil { + cancel() + } + + if sshSession == nil { + // If tunnel is set, we were connected before — this is an idempotent close. + if tunnel != nil { + return nil + } + return ErrNotConnected + } + + h.setConnectionStatus(ConnectionStatusDisconnected) + + sshSession.Close() + + // Unregister the endpoint unless the relay disconnected us for + // TooManyConnections (another host is authoritative). + if tunnel != nil && sshSession.DisconnectReason() != tunnelssh.SshDisconnectReasonTooManyConnections { + ctx := context.Background() + if err := h.manager.DeleteTunnelEndpoints(ctx, tunnel, h.endpointID, nil); err != nil { + h.logger.Printf("error deleting tunnel endpoint: %v", err) + } + } + + return nil +} + +// Wait blocks until the relay connection drops. +// If EnableReconnect is true, Wait will attempt to reconnect with +// exponential backoff before returning. +func (h *Host) Wait() error { + h.mu.Lock() + sshSession := h.ssh + h.mu.Unlock() + + if sshSession == nil { + return ErrNotConnected + } + + for { + err := sshSession.Wait() + + h.mu.Lock() + h.disconnectReason = sshSession.DisconnectReason() + disconnectReason := h.disconnectReason + reconnect := h.EnableReconnect + connCtx := h.ctx + h.mu.Unlock() + + h.setConnectionStatus(ConnectionStatusDisconnected) + + if !reconnect { + return err + } + + if disconnectReason == tunnelssh.SshDisconnectReasonTooManyConnections { + return ErrTooManyConnections + } + + if reconnectErr := h.reconnect(connCtx); reconnectErr != nil { + return reconnectErr + } + + // Reconnected — read new session and wait again. + h.mu.Lock() + sshSession = h.ssh + h.mu.Unlock() + + if sshSession == nil { + return ErrNotConnected + } + } +} + +// AddPort registers a port with the management API and notifies connected clients. +func (h *Host) AddPort(ctx context.Context, port *TunnelPort) error { + h.mu.Lock() + sshSession := h.ssh + tunnel := h.tunnel + h.mu.Unlock() + + if sshSession == nil { + return ErrNotConnected + } + + if sshSession.HasPort(port.PortNumber) { + return ErrPortAlreadyAdded + } + + // Register the port with the management API. + _, err := h.manager.CreateTunnelPort(ctx, tunnel, port, nil) + if err != nil { + // Tolerate 409 Conflict (port already exists on the service). + var tunnelErr *TunnelError + if !errors.As(err, &tunnelErr) || tunnelErr.StatusCode != http.StatusConflict { + return fmt.Errorf("error creating tunnel port: %w", err) + } + } + + // Extract access token for the relay request. + // V1 does not send tokens to the relay; V2 requires them. + var accessToken string + if sshSession.ConnectionProtocol() != tunnelssh.HostWebSocketSubProtocol { + if tunnel.AccessTokens != nil { + accessToken = tunnel.AccessTokens[TunnelAccessScopeHost] + } + } + + // Add to SSH session and send tcpip-forward to the relay (V2) or clients (V1). + sshSession.AddPort(port.PortNumber, accessToken) + + return nil +} + +// RemovePort removes a forwarded port and notifies connected clients. +func (h *Host) RemovePort(ctx context.Context, portNumber uint16) error { + h.mu.Lock() + sshSession := h.ssh + tunnel := h.tunnel + h.mu.Unlock() + + if sshSession == nil { + return ErrNotConnected + } + + // Unregister from the management API. Errors are logged but not returned. + if err := h.manager.DeleteTunnelPort(ctx, tunnel, portNumber, nil); err != nil { + h.logger.Printf("error deleting tunnel port %d: %v", portNumber, err) + } + + // Extract access token for the relay request. + // V1 does not send tokens to the relay; V2 requires them. + var accessToken string + if sshSession.ConnectionProtocol() != tunnelssh.HostWebSocketSubProtocol { + if tunnel.AccessTokens != nil { + accessToken = tunnel.AccessTokens[TunnelAccessScopeHost] + } + } + + sshSession.RemovePort(portNumber, accessToken) + + return nil +} + +// RefreshPorts synchronizes the local forwarded ports with the tunnel service. +// New ports on the service are added, and stale local ports are removed. +func (h *Host) RefreshPorts(ctx context.Context) error { + h.mu.Lock() + sshSession := h.ssh + tunnel := h.tunnel + h.mu.Unlock() + + if sshSession == nil { + return ErrNotConnected + } + + // Fetch tunnel with ports from the service. + opts := &TunnelRequestOptions{IncludePorts: true} + refreshed, err := h.manager.GetTunnel(ctx, tunnel, opts) + if err != nil { + return fmt.Errorf("error fetching tunnel for port refresh: %w", err) + } + + // Build a set of remote port numbers. + remotePorts := make(map[uint16]struct{}, len(refreshed.Ports)) + for _, p := range refreshed.Ports { + remotePorts[p.PortNumber] = struct{}{} + } + + // Get current local ports from the SSH session (single source of truth). + localPorts := sshSession.Ports() + localSet := make(map[uint16]struct{}, len(localPorts)) + for _, pn := range localPorts { + localSet[pn] = struct{}{} + } + + // Extract access token for the relay requests. + // V1 does not send tokens to the relay; V2 requires them. + var accessToken string + if sshSession.ConnectionProtocol() != tunnelssh.HostWebSocketSubProtocol { + if tunnel.AccessTokens != nil { + accessToken = tunnel.AccessTokens[TunnelAccessScopeHost] + } + } + + // Add ports that are on the service but not local. + for pn := range remotePorts { + if _, exists := localSet[pn]; !exists { + sshSession.AddPort(pn, accessToken) + } + } + + // Remove ports that are local but not on the service. + for _, pn := range localPorts { + if _, exists := remotePorts[pn]; !exists { + sshSession.RemovePort(pn, accessToken) + } + } + + // Update tunnel reference with refreshed ports. + h.mu.Lock() + h.tunnel.Ports = refreshed.Ports + h.mu.Unlock() + + return nil +} + +// generateUUID generates a UUID v4 string using crypto/rand. +func generateUUID() (string, error) { + var uuid [16]byte + if _, err := io.ReadFull(rand.Reader, uuid[:]); err != nil { + return "", err + } + // Set version (4) and variant (RFC 4122) + uuid[6] = (uuid[6] & 0x0f) | 0x40 + uuid[8] = (uuid[8] & 0x3f) | 0x80 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16]), nil +} diff --git a/go/tunnels/host_e2e_test.go b/go/tunnels/host_e2e_test.go new file mode 100644 index 00000000..ce3a9f8e --- /dev/null +++ b/go/tunnels/host_e2e_test.go @@ -0,0 +1,2148 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh" + tunnelstest "github.com/microsoft/dev-tunnels/go/tunnels/test" +) + +// e2eMockAPI provides a mock tunnel management API backed by httptest.Server. +// It tracks calls to key endpoints via atomic counters and allows dynamic +// control of relay URI, remote ports, and 401 simulation. +type e2eMockAPI struct { + server *httptest.Server + manager *Manager + deleteEndpointCalls int32 // atomic + createPortCalls int32 // atomic + portConflictOnce int32 // atomic flag for one-shot 409 simulation on port creation + remotePorts atomic.Value // stores []TunnelPort + relayURI atomic.Value // stores string + unauthorizedOnce int32 // atomic flag for one-shot 401 simulation +} + +// newE2EMockAPI creates an e2eMockAPI with its httptest.Server and a Manager +// configured to use it. The relayURI is set to the given initial value. +func newE2EMockAPI(t *testing.T, initialRelayURI string) *e2eMockAPI { + t.Helper() + + api := &e2eMockAPI{} + api.relayURI.Store(initialRelayURI) + api.remotePorts.Store([]TunnelPort{}) + + mux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for one-shot 401 simulation on UpdateTunnelEndpoint. + if r.Method == http.MethodPut && containsSegment(r.URL.Path, "endpoints") { + if atomic.CompareAndSwapInt32(&api.unauthorizedOnce, 1, 0) { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"detail": "token expired"}) + return + } + uri := api.relayURI.Load().(string) + endpoint := TunnelEndpoint{ + ID: "test-endpoint", + TunnelRelayTunnelEndpoint: TunnelRelayTunnelEndpoint{ + HostRelayURI: uri, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(endpoint) + return + } + + // DeleteTunnelEndpoints + if r.Method == http.MethodDelete && containsSegment(r.URL.Path, "endpoints") { + atomic.AddInt32(&api.deleteEndpointCalls, 1) + w.WriteHeader(http.StatusOK) + return + } + + // GetTunnel (GET .../tunnels/... but not .../ports/...) + if r.Method == http.MethodGet && containsSegment(r.URL.Path, "tunnels") && !containsSegment(r.URL.Path, "ports") { + ports := api.remotePorts.Load().([]TunnelPort) + tunnel := Tunnel{ + Name: "test-tunnel", + Ports: ports, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tunnel) + return + } + + // CreateTunnelPort (PUT .../ports/...) + if r.Method == http.MethodPut && containsSegment(r.URL.Path, "ports") { + if atomic.CompareAndSwapInt32(&api.portConflictOnce, 1, 0) { + w.WriteHeader(http.StatusConflict) + json.NewEncoder(w).Encode(map[string]string{"detail": "port already exists"}) + return + } + atomic.AddInt32(&api.createPortCalls, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TunnelPort{PortNumber: 0}) + return + } + + // DeleteTunnelPort + if r.Method == http.MethodDelete && containsSegment(r.URL.Path, "ports") { + w.WriteHeader(http.StatusOK) + return + } + + // GetTunnelPorts (GET .../ports/...) + if r.Method == http.MethodGet && containsSegment(r.URL.Path, "ports") { + ports := api.remotePorts.Load().([]TunnelPort) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ports) + return + } + + w.WriteHeader(http.StatusNotFound) + }) + + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + serviceURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse mock server URL: %v", err) + } + + mgr := &Manager{ + tokenProvider: func() string { return "" }, + httpClient: &http.Client{}, + uri: serviceURL, + userAgents: []UserAgent{{Name: "test", Version: "1.0"}}, + apiVersion: "2023-09-27-preview", + } + + api.server = server + api.manager = mgr + return api +} + +// containsSegment checks if a URL path contains a given segment delimited by slashes. +func containsSegment(path, segment string) bool { + // Simple substring check — sufficient for test routing. + return len(path) > 0 && findSubstring(path, "/"+segment+"/") || hasSuffix(path, "/"+segment) +} + +func findSubstring(s, sub string) bool { + return len(sub) <= len(s) && containsStr(s, sub) +} + +func containsStr(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} + +func hasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} + +// startEchoServerE2E starts a TCP echo server on 127.0.0.1:0. +// It registers cleanup to close the listener when the test ends. +// Returns the listener and port number. +func startEchoServerE2E(t *testing.T) (net.Listener, uint16) { + t.Helper() + + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + t.Cleanup(func() { listener.Close() }) + port := uint16(listener.Addr().(*net.TCPAddr).Port) + return listener, port +} + +func TestE2E_HostLifecycle(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API pointing to relay. + api := newE2EMockAPI(t, relay.URL()) + + logger := log.New(os.Stderr, "e2e-lifecycle: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Assert initial status is None. + if status := host.ConnectionStatus(); status != ConnectionStatusNone { + t.Fatalf("expected ConnectionStatusNone, got %v", status) + } + + // Assert HostPublicKeyBase64 is non-empty and base64-decodable. + pubKeyB64 := host.HostPublicKeyBase64() + if pubKeyB64 == "" { + t.Fatal("HostPublicKeyBase64 returned empty string") + } + if _, err := base64.StdEncoding.DecodeString(pubKeyB64); err != nil { + t.Fatalf("HostPublicKeyBase64 is not valid base64: %v", err) + } + + // Connect to the relay. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + // Wait for relay to confirm connection. + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Assert connected status. + if status := host.ConnectionStatus(); status != ConnectionStatusConnected { + t.Fatalf("expected ConnectionStatusConnected, got %v", status) + } + + // Close the host. + if err := host.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + // Assert disconnected status. + if status := host.ConnectionStatus(); status != ConnectionStatusDisconnected { + t.Fatalf("expected ConnectionStatusDisconnected, got %v", status) + } + + // Assert endpoint was deleted exactly once. + deleteCalls := atomic.LoadInt32(&api.deleteEndpointCalls) + if deleteCalls != 1 { + t.Fatalf("expected 1 deleteEndpointCalls, got %d", deleteCalls) + } + + // Second Close should be idempotent (return nil). + if err := host.Close(); err != nil { + t.Fatalf("second Host.Close should be idempotent, got: %v", err) + } + + // deleteEndpointCalls should still be 1. + deleteCalls = atomic.LoadInt32(&api.deleteEndpointCalls) + if deleteCalls != 1 { + t.Fatalf("expected deleteEndpointCalls still 1 after idempotent close, got %d", deleteCalls) + } +} + +func TestE2E_ErrorHandling(t *testing.T) { + t.Run("ErrNoManager", func(t *testing.T) { + _, err := NewHost(nil, nil) + if !errors.Is(err, ErrNoManager) { + t.Fatalf("expected ErrNoManager, got %v", err) + } + }) + + t.Run("ErrAlreadyConnected", func(t *testing.T) { + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + api := newE2EMockAPI(t, relay.URL()) + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("first Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + err = host.Connect(ctx, tunnel) + if !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("expected ErrAlreadyConnected, got %v", err) + } + }) + + t.Run("ErrNotConnected_Close", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + err = host.Close() + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } + }) + + t.Run("ErrNotConnected_Wait", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + err = host.Wait() + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } + }) + + t.Run("ErrNotConnected_AddPort", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx := context.Background() + err = host.AddPort(ctx, &TunnelPort{PortNumber: 8080}) + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } + }) + + t.Run("ErrNotConnected_RemovePort", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx := context.Background() + err = host.RemovePort(ctx, 8080) + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } + }) + + t.Run("ErrNotConnected_RefreshPorts", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx := context.Background() + err = host.RefreshPorts(ctx) + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } + }) + + t.Run("ErrPortAlreadyAdded", func(t *testing.T) { + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + api := newE2EMockAPI(t, relay.URL()) + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + port := &TunnelPort{PortNumber: 8080} + if err := host.AddPort(ctx, port); err != nil { + t.Fatalf("first AddPort failed: %v", err) + } + + err = host.AddPort(ctx, port) + if !errors.Is(err, ErrPortAlreadyAdded) { + t.Fatalf("expected ErrPortAlreadyAdded, got %v", err) + } + }) + + t.Run("ErrTooManyConnections", func(t *testing.T) { + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Set disconnect reason to TooManyConnections (same package — can access internal fields). + host.mu.Lock() + host.disconnectReason = tunnelssh.SshDisconnectReasonTooManyConnections + host.mu.Unlock() + + ctx := context.Background() + tunnel := &Tunnel{Name: "test-tunnel"} + + err = host.Connect(ctx, tunnel) + if !errors.Is(err, ErrTooManyConnections) { + t.Fatalf("expected ErrTooManyConnections, got %v", err) + } + }) + + t.Run("ErrNoHostRelayURI", func(t *testing.T) { + // Mock API with empty relay URI -> UpdateTunnelEndpoint returns empty HostRelayURI. + api := newE2EMockAPI(t, "") + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + err = host.Connect(ctx, tunnel) + if !errors.Is(err, ErrNoHostRelayURI) { + t.Fatalf("expected ErrNoHostRelayURI, got %v", err) + } + }) + + t.Run("409Conflict_Tolerated", func(t *testing.T) { + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + api := newE2EMockAPI(t, relay.URL()) + host, err := NewHost(log.New(os.Stderr, "e2e-err: ", log.LstdFlags), api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Set flag so next port creation returns 409 Conflict. + atomic.StoreInt32(&api.portConflictOnce, 1) + + port := &TunnelPort{PortNumber: 9999} + if err := host.AddPort(ctx, port); err != nil { + t.Fatalf("AddPort should tolerate 409 Conflict, got: %v", err) + } + }) +} + +func TestE2E_PortForwardingDataFlow(t *testing.T) { + // Start TCP echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-dataflow: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port (host sends tcpip-forward to relay internally). + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process the tcpip-forward request. + time.Sleep(200 * time.Millisecond) + + // Simulate a client connection via the relay — opens a forwarded-tcpip channel directly. + clientConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection failed: %v", err) + } + + // Write test data through the tunnel (V2: net.Conn is the data stream). + testData := []byte("hello e2e tunnel") + if _, err := clientConn.Write(testData); err != nil { + t.Fatalf("failed to write through tunnel: %v", err) + } + + // Read echo response. + buf := make([]byte, len(testData)) + if _, err := io.ReadFull(clientConn, buf); err != nil { + t.Fatalf("failed to read echo response: %v", err) + } + + // Assert sent bytes == received bytes. + if string(buf) != string(testData) { + t.Fatalf("data integrity check failed: sent %q, received %q", testData, buf) + } + + clientConn.Close() +} + +func TestE2E_DirectTcpipAndForwardedTcpip(t *testing.T) { + // Start echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-channel-types: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process the tcpip-forward request. + time.Sleep(200 * time.Millisecond) + + // Test 1: forwarded-tcpip channel via SimulateClientConnection — echo "forwarded-test". + fwdConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection failed: %v", err) + } + + fwdMsg := []byte("forwarded-test") + if _, err := fwdConn.Write(fwdMsg); err != nil { + t.Fatalf("failed to write forwarded-tcpip: %v", err) + } + fwdBuf := make([]byte, len(fwdMsg)) + if _, err := io.ReadFull(fwdConn, fwdBuf); err != nil { + t.Fatalf("failed to read forwarded-tcpip echo: %v", err) + } + if string(fwdBuf) != string(fwdMsg) { + t.Fatalf("forwarded-tcpip echo mismatch: sent %q, got %q", fwdMsg, fwdBuf) + } + fwdConn.Close() + + // Test 2: second forwarded-tcpip channel — echo "direct-test". + // In V2, both forwarded-tcpip and direct-tcpip are handled via the relay. + // SimulateClientConnection uses forwarded-tcpip which covers the V2 data path. + directConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection (second) failed: %v", err) + } + + directMsg := []byte("direct-test") + if _, err := directConn.Write(directMsg); err != nil { + t.Fatalf("failed to write second channel: %v", err) + } + directBuf := make([]byte, len(directMsg)) + if _, err := io.ReadFull(directConn, directBuf); err != nil { + t.Fatalf("failed to read second channel echo: %v", err) + } + if string(directBuf) != string(directMsg) { + t.Fatalf("second channel echo mismatch: sent %q, got %q", directMsg, directBuf) + } + directConn.Close() + + // Test 3: connection to unregistered port — should be rejected. + unregPort := echoPort + 1000 + _, err = relay.SimulateClientConnection(unregPort) + if err == nil { + t.Fatal("expected error connecting to unregistered port, got nil") + } +} + +func TestE2E_MultiplePorts(t *testing.T) { + // Start 3 echo servers on different ports. + _, portA := startEchoServerE2E(t) + _, portB := startEchoServerE2E(t) + _, portC := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-multiport: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add all 3 ports. + for _, port := range []uint16{portA, portB, portC} { + if err := host.AddPort(ctx, &TunnelPort{PortNumber: port}); err != nil { + t.Fatalf("AddPort(%d) failed: %v", port, err) + } + } + + // Give the relay time to process all tcpip-forward requests. + time.Sleep(200 * time.Millisecond) + + // Verify relay registered all three ports. + for _, port := range []uint16{portA, portB, portC} { + if !relay.HasPort(port) { + t.Fatalf("relay did not register port %d", port) + } + } + + // Test each port with unique messages — no cross-contamination. + ports := []uint16{portA, portB, portC} + messages := []string{"port-A", "port-B", "port-C"} + + for i, port := range ports { + clientConn, err := relay.SimulateClientConnection(port) + if err != nil { + t.Fatalf("SimulateClientConnection(%d) failed: %v", port, err) + } + msg := []byte(messages[i]) + if _, err := clientConn.Write(msg); err != nil { + t.Fatalf("port %d: write failed: %v", port, err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(clientConn, buf); err != nil { + t.Fatalf("port %d: read failed: %v", port, err) + } + if string(buf) != string(msg) { + t.Fatalf("port %d: echo mismatch: sent %q, got %q", port, msg, buf) + } + clientConn.Close() + } +} + +func TestE2E_DynamicPortManagement(t *testing.T) { + // Start echo server A. + _, echoPortA := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-dynamic-ports: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add port A. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPortA}); err != nil { + t.Fatalf("AddPort(A) failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Verify data flows through port A. + connA, err := relay.SimulateClientConnection(echoPortA) + if err != nil { + t.Fatalf("SimulateClientConnection(A) failed: %v", err) + } + msgA := []byte("dynamic-port-A") + if _, err := connA.Write(msgA); err != nil { + t.Fatalf("port A: write failed: %v", err) + } + bufA := make([]byte, len(msgA)) + if _, err := io.ReadFull(connA, bufA); err != nil { + t.Fatalf("port A: read failed: %v", err) + } + if string(bufA) != string(msgA) { + t.Fatalf("port A: echo mismatch: sent %q, got %q", msgA, bufA) + } + connA.Close() + + // Start echo server B and add port B dynamically. + _, echoPortB := startEchoServerE2E(t) + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPortB}); err != nil { + t.Fatalf("AddPort(B) failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Verify data flows through port B. + connB, err := relay.SimulateClientConnection(echoPortB) + if err != nil { + t.Fatalf("SimulateClientConnection(B) failed: %v", err) + } + msgB := []byte("dynamic-port-B") + if _, err := connB.Write(msgB); err != nil { + t.Fatalf("port B: write failed: %v", err) + } + bufB := make([]byte, len(msgB)) + if _, err := io.ReadFull(connB, bufB); err != nil { + t.Fatalf("port B: read failed: %v", err) + } + if string(bufB) != string(msgB) { + t.Fatalf("port B: echo mismatch: sent %q, got %q", msgB, bufB) + } + connB.Close() + + // Remove port A. + if err := host.RemovePort(ctx, echoPortA); err != nil { + t.Fatalf("RemovePort(A) failed: %v", err) + } + + // Give the relay time to process cancel-tcpip-forward. + time.Sleep(200 * time.Millisecond) + + // Verify the relay no longer has port A. + if relay.HasPort(echoPortA) { + t.Fatal("relay should not have port A after RemovePort") + } + + // Attempt to connect to removed port A — expect rejection. + _, err = relay.SimulateClientConnection(echoPortA) + if err == nil { + t.Fatal("expected error connecting to removed port A") + } +} + +func TestE2E_PortDuplicateHandling(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-dup-port: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // First AddPort should succeed. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: 8080}); err != nil { + t.Fatalf("first AddPort failed: %v", err) + } + + // Second AddPort with same port should return ErrPortAlreadyAdded. + err = host.AddPort(ctx, &TunnelPort{PortNumber: 8080}) + if !errors.Is(err, ErrPortAlreadyAdded) { + t.Fatalf("expected ErrPortAlreadyAdded, got %v", err) + } + + // Access SSH session and verify port 8080 appears exactly once. + host.mu.Lock() + sshSession := host.ssh + host.mu.Unlock() + + ports := sshSession.Ports() + count := 0 + for _, p := range ports { + if p == 8080 { + count++ + } + } + if count != 1 { + t.Fatalf("expected port 8080 exactly once, found %d times in %v", count, ports) + } + + // Verify createPortCalls == 1 (only the first AddPort called the API). + createCalls := atomic.LoadInt32(&api.createPortCalls) + if createCalls != 1 { + t.Fatalf("expected 1 createPortCalls, got %d", createCalls) + } +} + +func TestE2E_RefreshPorts(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-refresh-ports: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Access SSH session directly (same package — internal fields accessible). + host.mu.Lock() + sshSession := host.ssh + host.mu.Unlock() + + // Manually add port 5000 to the SSH session. + sshSession.AddPort(5000, "test-token") + + // Configure mock API to return ports 3000 and 4000 as the remote set. + api.remotePorts.Store([]TunnelPort{ + {PortNumber: 3000}, + {PortNumber: 4000}, + }) + + // RefreshPorts should synchronize: add 3000, 4000; remove 5000. + if err := host.RefreshPorts(ctx); err != nil { + t.Fatalf("RefreshPorts failed: %v", err) + } + + // Assert port 3000 was added from service. + if !sshSession.HasPort(3000) { + t.Fatal("expected port 3000 to be present after RefreshPorts") + } + + // Assert port 4000 was added from service. + if !sshSession.HasPort(4000) { + t.Fatal("expected port 4000 to be present after RefreshPorts") + } + + // Assert port 5000 was removed (not on service). + if sshSession.HasPort(5000) { + t.Fatal("expected port 5000 to be removed after RefreshPorts") + } + + // Call RefreshPorts again — should be idempotent (no changes). + if err := host.RefreshPorts(ctx); err != nil { + t.Fatalf("idempotent RefreshPorts failed: %v", err) + } + + // Verify ports are still correct after idempotent call. + if !sshSession.HasPort(3000) { + t.Fatal("expected port 3000 still present after idempotent RefreshPorts") + } + if !sshSession.HasPort(4000) { + t.Fatal("expected port 4000 still present after idempotent RefreshPorts") + } + if sshSession.HasPort(5000) { + t.Fatal("port 5000 should still be absent after idempotent RefreshPorts") + } +} + +func TestE2E_LargeDataTransfer(t *testing.T) { + // Start echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-large-data: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Simulate a client connection via the relay. + clientConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection failed: %v", err) + } + + // Generate 1MB payload: payload[i] = byte(i % 256). + const payloadSize = 1048576 // 1MB + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i % 256) + } + + // Compute expected SHA256 hash. + expectedHash := sha256.Sum256(payload) + + // Goroutine: write payload, then close write side. + writeErr := make(chan error, 1) + go func() { + _, err := clientConn.Write(payload) + // Close write side to signal EOF to the echo server. + // channelNetConn wraps ssh.Channel which supports CloseWrite via Close. + writeErr <- err + }() + + // Main: read all echoed data. + var received bytes.Buffer + if _, err := io.CopyN(&received, clientConn, payloadSize); err != nil { + t.Fatalf("failed to read echo response: %v", err) + } + + // Check write error. + if err := <-writeErr; err != nil { + t.Fatalf("failed to write payload: %v", err) + } + + // Assert received length == 1MB. + if received.Len() != payloadSize { + t.Fatalf("expected %d bytes, received %d", payloadSize, received.Len()) + } + + // Compute actual SHA256 hash and compare. + actualHash := sha256.Sum256(received.Bytes()) + if expectedHash != actualHash { + t.Fatalf("SHA256 hash mismatch: expected %x, got %x", expectedHash, actualHash) + } + + clientConn.Close() +} + +func TestE2E_ConnectionStatusCallbacks(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API pointing to relay. + api := newE2EMockAPI(t, relay.URL()) + + logger := log.New(os.Stderr, "e2e-callbacks: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Assert initial status is None. + if status := host.ConnectionStatus(); status != ConnectionStatusNone { + t.Fatalf("expected initial ConnectionStatusNone, got %v", status) + } + + // Register callback that appends transitions to a mutex-guarded slice. + var mu sync.Mutex + var transitions []ConnectionStatus + host.ConnectionStatusChanged = func(prev, curr ConnectionStatus) { + mu.Lock() + transitions = append(transitions, curr) + mu.Unlock() + } + + // Connect to the relay. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + // Wait for relay to confirm connection. + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Close the host. + if err := host.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + // Assert transitions == [Connecting, Connected, Disconnected]. + mu.Lock() + got := make([]ConnectionStatus, len(transitions)) + copy(got, transitions) + mu.Unlock() + + expected := []ConnectionStatus{ + ConnectionStatusConnecting, + ConnectionStatusConnected, + ConnectionStatusDisconnected, + } + + if len(got) != len(expected) { + t.Fatalf("expected %d transitions, got %d: %v", len(expected), len(got), got) + } + + for i, exp := range expected { + if got[i] != exp { + t.Fatalf("transition[%d]: expected %v, got %v (full: %v)", i, exp, got[i], got) + } + } +} + +func TestE2E_BidirectionalStreaming(t *testing.T) { + // Start echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-bidir: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Simulate a client connection via the relay. + clientConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection failed: %v", err) + } + + // Send 10 messages of increasing size (100, 200, ..., 1000 bytes). + for i := 1; i <= 10; i++ { + size := i * 100 + msg := bytes.Repeat([]byte{byte(i)}, size) + + if _, err := clientConn.Write(msg); err != nil { + t.Fatalf("message %d: write failed: %v", i, err) + } + + buf := make([]byte, size) + if _, err := io.ReadFull(clientConn, buf); err != nil { + t.Fatalf("message %d: read failed: %v", i, err) + } + + if !bytes.Equal(buf, msg) { + t.Fatalf("message %d: echo mismatch: sent %d bytes of 0x%02x, got different content", i, size, byte(i)) + } + } + + clientConn.Close() +} + +func TestE2E_MultipleConcurrentClients(t *testing.T) { + // Start echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-concurrent-clients: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Spawn 5 goroutines, each simulating a client connection and verifying echo. + // In V2, each call to SimulateClientConnection opens a new forwarded-tcpip channel. + const numClients = 5 + var wg sync.WaitGroup + errs := make(chan error, numClients) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(clientID int) { + defer wg.Done() + + // Simulate a client connection via the relay. + clientConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + errs <- fmt.Errorf("client-%d: SimulateClientConnection failed: %v", clientID, err) + return + } + + // Write unique message and verify echo. + msg := []byte(fmt.Sprintf("client-%d", clientID)) + if _, err := clientConn.Write(msg); err != nil { + errs <- fmt.Errorf("client-%d: write failed: %v", clientID, err) + return + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(clientConn, buf); err != nil { + errs <- fmt.Errorf("client-%d: read failed: %v", clientID, err) + return + } + if string(buf) != string(msg) { + errs <- fmt.Errorf("client-%d: echo mismatch: sent %q, got %q", clientID, msg, buf) + return + } + clientConn.Close() + }(i) + } + + wg.Wait() + close(errs) + + // Collect all errors. + for err := range errs { + t.Error(err) + } +} + +func TestE2E_ConcurrentPortOperations(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-concurrent-ports: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Access SSH session for direct port manipulation. + host.mu.Lock() + sshSession := host.ssh + host.mu.Unlock() + + // Launch 10 goroutines: + // 0-4: add ports 9000-9004 (these stay) + // 5-9: add then remove ports 9005-9009 (these are removed) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + port := uint16(9000 + idx) + sshSession.AddPort(port, "test-token") + if idx >= 5 { + sshSession.RemovePort(port, "test-token") + } + }(i) + } + + wg.Wait() + + // Verify ports 9000-9004 exist. + for i := 0; i < 5; i++ { + port := uint16(9000 + i) + if !sshSession.HasPort(port) { + t.Fatalf("expected port %d to exist", port) + } + } + + // Verify ports 9005-9009 do NOT exist. + for i := 5; i < 10; i++ { + port := uint16(9000 + i) + if sshSession.HasPort(port) { + t.Fatalf("expected port %d to not exist (was added then removed)", port) + } + } + + // Verify exactly 5 ports remain. + ports := sshSession.Ports() + if len(ports) != 5 { + t.Fatalf("expected 5 ports, got %d: %v", len(ports), ports) + } +} + +func TestE2E_IPv4AndIPv6(t *testing.T) { + // Start IPv4 echo server on 127.0.0.1:0. + _, ipv4Port := startEchoServerE2E(t) + + // Start IPv6 echo server on [::1]:0. + // Skip the test if IPv6 is not available on this machine. + ipv6Listener, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available") + } + ipv6Port := uint16(ipv6Listener.Addr().(*net.TCPAddr).Port) + go func() { + for { + conn, err := ipv6Listener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + t.Cleanup(func() { ipv6Listener.Close() }) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-ipv4v6: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add both ports. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: ipv4Port}); err != nil { + t.Fatalf("AddPort(ipv4) failed: %v", err) + } + if err := host.AddPort(ctx, &TunnelPort{PortNumber: ipv6Port}); err != nil { + t.Fatalf("AddPort(ipv6) failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Test IPv4: connect via relay, send "ipv4-test", verify echo. + connV4, err := relay.SimulateClientConnection(ipv4Port) + if err != nil { + t.Fatalf("SimulateClientConnection(ipv4) failed: %v", err) + } + msgV4 := []byte("ipv4-test") + if _, err := connV4.Write(msgV4); err != nil { + t.Fatalf("IPv4 write failed: %v", err) + } + bufV4 := make([]byte, len(msgV4)) + if _, err := io.ReadFull(connV4, bufV4); err != nil { + t.Fatalf("IPv4 read failed: %v", err) + } + if string(bufV4) != string(msgV4) { + t.Fatalf("IPv4 echo mismatch: sent %q, got %q", msgV4, bufV4) + } + connV4.Close() + + // Test IPv6: connect via relay, send "ipv6-test", verify echo. + connV6, err := relay.SimulateClientConnection(ipv6Port) + if err != nil { + t.Fatalf("SimulateClientConnection(ipv6) failed: %v", err) + } + msgV6 := []byte("ipv6-test") + if _, err := connV6.Write(msgV6); err != nil { + t.Fatalf("IPv6 write failed: %v", err) + } + bufV6 := make([]byte, len(msgV6)) + if _, err := io.ReadFull(connV6, bufV6); err != nil { + t.Fatalf("IPv6 read failed: %v", err) + } + if string(bufV6) != string(msgV6) { + t.Fatalf("IPv6 echo mismatch: sent %q, got %q", msgV6, bufV6) + } + connV6.Close() +} + +func TestE2E_ConnectionRefused(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-conn-refused: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add port 19999 — no listener on this port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: 19999}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Simulate a client connection to port 19999 (no listener). + // The host accepts the channel but the local dial fails, so it closes the channel. + clientConn, err := relay.SimulateClientConnection(19999) + if err != nil { + t.Fatalf("SimulateClientConnection failed: %v", err) + } + + // Read should return EOF or error (no panic). + buf := make([]byte, 64) + _, readErr := clientConn.Read(buf) + if readErr == nil { + t.Fatal("expected read error (EOF or other) on connection-refused port, got nil") + } + + clientConn.Close() + + // No panic has occurred — host should remain functional. + // Verify by starting an echo server, adding its port, and testing data flow. + _, echoPort := startEchoServerE2E(t) + + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort(echo) failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Verify data flows through the echo port — host is still functional. + echoConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection(echo) failed: %v", err) + } + msg := []byte("still-alive") + if _, err := echoConn.Write(msg); err != nil { + t.Fatalf("echo write failed: %v", err) + } + echoBuf := make([]byte, len(msg)) + if _, err := io.ReadFull(echoConn, echoBuf); err != nil { + t.Fatalf("echo read failed: %v", err) + } + if string(echoBuf) != string(msg) { + t.Fatalf("echo mismatch: sent %q, got %q", msg, echoBuf) + } + echoConn.Close() +} + +func TestE2E_ClientDisconnectMidTransfer(t *testing.T) { + // Start echo server. + _, echoPort := startEchoServerE2E(t) + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-client-disconnect: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Add the echo port. + if err := host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}); err != nil { + t.Fatalf("AddPort failed: %v", err) + } + + // Give the relay time to process. + time.Sleep(200 * time.Millisecond) + + // Simulate client 1 connection. + clientConn1, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection(1) failed: %v", err) + } + + // Write partial data to the channel. + if _, err := clientConn1.Write([]byte("partial data")); err != nil { + t.Fatalf("client 1: write failed: %v", err) + } + + // Immediately close client 1 connection (abrupt disconnect). + clientConn1.Close() + + // No panic after 200ms sleep. + time.Sleep(200 * time.Millisecond) + + // Simulate client 2 — host should still be functional. + clientConn2, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("SimulateClientConnection(2) failed: %v", err) + } + + // Write "full message" and verify echo response. + msg := []byte("full message") + if _, err := clientConn2.Write(msg); err != nil { + t.Fatalf("client 2: write failed: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(clientConn2, buf); err != nil { + t.Fatalf("client 2: read failed: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("client 2: echo mismatch: sent %q, got %q", msg, buf) + } + clientConn2.Close() +} + +func TestE2E_WaitBlocksUntilDisconnect(t *testing.T) { + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-wait: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Launch host.Wait() in a goroutine and send the result to a channel. + waitResult := make(chan error, 1) + go func() { + waitResult <- host.Wait() + }() + + // After 200ms, verify the channel is empty (Wait is still blocking). + time.Sleep(200 * time.Millisecond) + select { + case err := <-waitResult: + t.Fatalf("Wait returned prematurely with error: %v", err) + default: + // Good — Wait is still blocking. + } + + // Close the relay server to simulate relay drop. + relay.Close() + + // Wait for result from channel with 5s timeout. + select { + case err := <-waitResult: + // Wait should return a non-nil error when the relay drops. + if err == nil { + t.Fatal("expected non-nil error from Wait after relay close, got nil") + } + case <-time.After(5 * time.Second): + t.Fatal("Wait did not return within 5s after relay close") + } +} + +func TestE2E_Reconnection(t *testing.T) { + // Create relay 1 with access token. + relay1, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay 1: %v", err) + } + + // Create mock management API initially pointing to relay 1. + api := newE2EMockAPI(t, relay1.URL()) + logger := log.New(os.Stderr, "e2e-reconnect: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Enable reconnection. + host.EnableReconnect = true + + // Track status transitions via callback. + var mu sync.Mutex + var transitions []ConnectionStatus + host.ConnectionStatusChanged = func(prev, curr ConnectionStatus) { + mu.Lock() + transitions = append(transitions, curr) + mu.Unlock() + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + // Connect to relay 1. + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relay1.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay 1 did not receive connection: %v", err) + } + + // Verify connected. + if status := host.ConnectionStatus(); status != ConnectionStatusConnected { + t.Fatalf("expected Connected, got %v", status) + } + + // Launch Wait() in a goroutine — it will handle reconnection. + waitResult := make(chan error, 1) + go func() { + waitResult <- host.Wait() + }() + + // Create relay 2. + relay2, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay 2: %v", err) + } + t.Cleanup(func() { relay2.Close() }) + + // Update API to point to relay 2. + api.relayURI.Store(relay2.URL()) + + // Close relay 1 to simulate disconnect. + relay1.Close() + + // Wait for relay 2 to receive the reconnection (30s timeout). + if err := relay2.WaitForConnection(30 * time.Second); err != nil { + t.Fatalf("relay 2 did not receive reconnection: %v", err) + } + + // Wait for host status to settle to Connected (relay connects before host finishes connectOnce). + deadline := time.After(5 * time.Second) + for host.ConnectionStatus() != ConnectionStatusConnected { + select { + case <-deadline: + t.Fatalf("expected Connected after reconnection, got %v", host.ConnectionStatus()) + case <-time.After(10 * time.Millisecond): + } + } + + // Verify status transitions include: Connected -> Disconnected -> Connecting -> Connected. + mu.Lock() + got := make([]ConnectionStatus, len(transitions)) + copy(got, transitions) + mu.Unlock() + + expectedSubseq := []ConnectionStatus{ + ConnectionStatusConnected, + ConnectionStatusDisconnected, + ConnectionStatusConnecting, + ConnectionStatusConnected, + } + + found := false + for i := 0; i <= len(got)-len(expectedSubseq); i++ { + match := true + for j, exp := range expectedSubseq { + if got[i+j] != exp { + match = false + break + } + } + if match { + found = true + break + } + } + if !found { + t.Fatalf("expected transitions to include %v, got %v", expectedSubseq, got) + } + + // Close host. + if err := host.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + // Wait for Wait() to return. + select { + case <-waitResult: + // Wait returned — good. + case <-time.After(5 * time.Second): + t.Fatal("Wait did not return within 5s after host close") + } +} + +func TestE2E_HostPublicKeyAvailable(t *testing.T) { + // In V2 there is no nested SSH server exposed to clients, so we cannot + // capture the host key via an SSH handshake. Instead, verify that the + // host exposes a valid base64-encoded public key that clients can + // retrieve via the management API (HostPublicKeys field on the endpoint). + + // Create relay server with access token. + relay, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + t.Cleanup(func() { relay.Close() }) + + // Create mock management API and host. + api := newE2EMockAPI(t, relay.URL()) + logger := log.New(os.Stderr, "e2e-pubkey: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Get the host public key (base64-encoded). + expectedKey := host.HostPublicKeyBase64() + if expectedKey == "" { + t.Fatal("HostPublicKeyBase64 returned empty string") + } + + // Verify the key is valid base64. + keyBytes, err := base64.StdEncoding.DecodeString(expectedKey) + if err != nil { + t.Fatalf("HostPublicKeyBase64 is not valid base64: %v", err) + } + + // Verify the key bytes are non-empty and have a reasonable length. + if len(keyBytes) < 16 { + t.Fatalf("decoded public key is too short: %d bytes", len(keyBytes)) + } + + // Connect and verify the key remains stable after connection. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer host.Close() + + if err := relay.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Key should be the same after connection. + afterConnectKey := host.HostPublicKeyBase64() + if afterConnectKey != expectedKey { + t.Fatalf("public key changed after connect:\n before: %s\n after: %s", expectedKey, afterConnectKey) + } +} + +func TestE2E_TokenRefresh(t *testing.T) { + // Create relay 1 with access token. + relay1, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay 1: %v", err) + } + + // Create mock management API initially pointing to relay 1. + api := newE2EMockAPI(t, relay1.URL()) + logger := log.New(os.Stderr, "e2e-token-refresh: ", log.LstdFlags) + host, err := NewHost(logger, api.manager) + if err != nil { + t.Fatalf("NewHost failed: %v", err) + } + + // Enable reconnection. + host.EnableReconnect = true + + // Set up token refresh callback that increments an atomic counter. + var tokenRefreshCalls int32 + host.RefreshTunnelAccessTokenFunc = func(ctx context.Context) (string, error) { + atomic.AddInt32(&tokenRefreshCalls, 1) + return "refreshed-token", nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } + + // Connect to relay 1. + if err := host.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relay1.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay 1 did not receive connection: %v", err) + } + + // Verify connected. + if status := host.ConnectionStatus(); status != ConnectionStatusConnected { + t.Fatalf("expected Connected, got %v", status) + } + + // Launch Wait() in a goroutine — it will handle reconnection. + waitResult := make(chan error, 1) + go func() { + waitResult <- host.Wait() + }() + + // Set unauthorizedOnce so next UpdateTunnelEndpoint returns 401. + atomic.StoreInt32(&api.unauthorizedOnce, 1) + + // Create relay 2. + relay2, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel refreshed-token"), + ) + if err != nil { + t.Fatalf("failed to create relay 2: %v", err) + } + t.Cleanup(func() { relay2.Close() }) + + // Update API to point to relay 2. + api.relayURI.Store(relay2.URL()) + + // Close relay 1 to trigger reconnection. + relay1.Close() + + // Wait for relay 2 to receive the reconnection (30s timeout). + if err := relay2.WaitForConnection(30 * time.Second); err != nil { + t.Fatalf("relay 2 did not receive reconnection: %v", err) + } + + // Wait for host status to settle to Connected. + deadline := time.After(5 * time.Second) + for host.ConnectionStatus() != ConnectionStatusConnected { + select { + case <-deadline: + t.Fatalf("expected Connected after reconnection, got %v", host.ConnectionStatus()) + case <-time.After(10 * time.Millisecond): + } + } + + // Assert token refresh was called at least once. + calls := atomic.LoadInt32(&tokenRefreshCalls) + if calls < 1 { + t.Fatalf("expected tokenRefreshCalls >= 1, got %d", calls) + } + + // Close host. + if err := host.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + // Wait for Wait() to return. + select { + case <-waitResult: + // Wait returned — good. + case <-time.After(5 * time.Second): + t.Fatal("Wait did not return within 5s after host close") + } +} diff --git a/go/tunnels/host_reconnect.go b/go/tunnels/host_reconnect.go new file mode 100644 index 00000000..e1d7bbd7 --- /dev/null +++ b/go/tunnels/host_reconnect.go @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" +) + +const ( + reconnectBaseDelay = 100 * time.Millisecond + reconnectMaxDelay = 12800 * time.Millisecond +) + +// reconnect attempts to re-establish the relay connection with exponential backoff. +func (h *Host) reconnect(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + delay := reconnectBaseDelay + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + h.logger.Printf("reconnecting in %v", delay) + + select { + case <-time.After(delay): + case <-ctx.Done(): + return ctx.Err() + } + + h.mu.Lock() + tunnel := h.tunnel + h.mu.Unlock() + + err := h.connectOnce(ctx, tunnel) + if err == nil { + h.logger.Printf("reconnected to relay") + return nil + } + + // On 401 Unauthorized, try refreshing the access token. + var tunnelErr *TunnelError + if errors.As(err, &tunnelErr) && tunnelErr.StatusCode == http.StatusUnauthorized { + if refreshErr := h.refreshAccessToken(ctx); refreshErr != nil { + h.logger.Printf("error refreshing access token: %v", refreshErr) + } else { + // Token refreshed — retry immediately with the same delay. + continue + } + } + + h.logger.Printf("error reconnecting: %v", err) + + // Exponential backoff. + delay *= 2 + if delay > reconnectMaxDelay { + delay = reconnectMaxDelay + } + } +} + +// refreshAccessToken attempts to refresh the tunnel access token using the +// callback, or by re-fetching the tunnel from the management service. +func (h *Host) refreshAccessToken(ctx context.Context) error { + h.mu.Lock() + tunnel := h.tunnel + cb := h.RefreshTunnelAccessTokenFunc + h.mu.Unlock() + + if cb != nil { + token, err := cb(ctx) + if err != nil { + return err + } + h.mu.Lock() + if tunnel.AccessTokens == nil { + tunnel.AccessTokens = make(map[TunnelAccessScope]string) + } + tunnel.AccessTokens[TunnelAccessScopeHost] = token + h.mu.Unlock() + return nil + } + + // Fallback: re-fetch the tunnel from the management service. + opts := &TunnelRequestOptions{ + TokenScopes: TunnelAccessScopes{TunnelAccessScopeHost}, + } + refreshed, err := h.manager.GetTunnel(ctx, tunnel, opts) + if err != nil { + return fmt.Errorf("error refreshing tunnel: %w", err) + } + + h.mu.Lock() + if refreshed.AccessTokens != nil { + if tunnel.AccessTokens == nil { + tunnel.AccessTokens = make(map[TunnelAccessScope]string) + } + for scope, token := range refreshed.AccessTokens { + tunnel.AccessTokens[scope] = token + } + } + h.mu.Unlock() + return nil +} diff --git a/go/tunnels/host_test.go b/go/tunnels/host_test.go new file mode 100644 index 00000000..d45a375a --- /dev/null +++ b/go/tunnels/host_test.go @@ -0,0 +1,946 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "errors" + "io" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "sync" + "testing" + "time" + + tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh" + tunnelstest "github.com/microsoft/dev-tunnels/go/tunnels/test" + "golang.org/x/crypto/ssh" +) + +func newTestManager() *Manager { + return &Manager{ + tokenProvider: func() string { return "" }, + } +} + +func TestNewHostReturnsErrWhenManagerIsNil(t *testing.T) { + _, err := NewHost(nil, nil) + if !errors.Is(err, ErrNoManager) { + t.Fatalf("expected ErrNoManager, got %v", err) + } +} + +func TestNewHostGeneratesUniqueHostID(t *testing.T) { + mgr := newTestManager() + h1, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + h2, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h1.hostID == h2.hostID { + t.Fatalf("expected different hostIDs, got %s and %s", h1.hostID, h2.hostID) + } +} + +func TestNewHostGeneratesValidHostKey(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.hostKey == nil { + t.Fatal("expected non-nil host key") + } + + // Verify the key can sign and verify. + pubKey := h.hostKey.PublicKey() + if pubKey == nil { + t.Fatal("expected non-nil public key") + } + + // Verify it's an Ed25519 key. + if pubKey.Type() != ssh.KeyAlgoED25519 { + t.Fatalf("expected Ed25519 key, got %s", pubKey.Type()) + } + + // Verify sign/verify works. + data := []byte("test data") + sig, err := h.hostKey.Sign(nil, data) + if err != nil { + t.Fatalf("failed to sign: %v", err) + } + + // Parse the SSH public key to get the underlying ed25519 key for verification. + parsedKey, err := ssh.ParsePublicKey(pubKey.Marshal()) + if err != nil { + t.Fatalf("failed to parse public key: %v", err) + } + + if err := parsedKey.Verify(data, sig); err != nil { + t.Fatalf("signature verification failed: %v", err) + } + + // Verify the underlying key type with comma-ok assertion. + cryptoKey, ok := parsedKey.(ssh.CryptoPublicKey) + if !ok { + t.Fatal("expected ssh.CryptoPublicKey") + } + cryptoPubKey := cryptoKey.CryptoPublicKey() + if _, ok := cryptoPubKey.(ed25519.PublicKey); !ok { + t.Fatalf("expected ed25519.PublicKey, got %T", cryptoPubKey) + } +} + +func TestNewHostSetsEndpointID(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := h.hostID + "-relay" + if h.endpointID != expected { + t.Fatalf("expected endpointID %q, got %q", expected, h.endpointID) + } +} + +func TestHostConnectReturnsErrWhenAlreadyConnected(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Simulate being connected by setting the ssh field. + h.ssh = &tunnelssh.HostSSHSession{} + + err = h.Connect(context.Background(), &Tunnel{}) + if !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("expected ErrAlreadyConnected, got %v", err) + } +} + +func TestHostWaitReturnsErrWhenNotConnected(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + err = h.Wait() + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } +} + +func TestHostCloseReturnsErrWhenNotConnected(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + err = h.Close() + if !errors.Is(err, ErrNotConnected) { + t.Fatalf("expected ErrNotConnected, got %v", err) + } +} + +// newMockManagementAPI creates an httptest.Server that mocks the tunnel management API. +// It handles UpdateTunnelEndpoint (PUT) and DeleteTunnelEndpoints (DELETE). +// The hostRelayURI is returned in the endpoint response. +func newMockManagementAPI(t *testing.T, hostRelayURI string) (*httptest.Server, *Manager) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle UpdateTunnelEndpoint (PUT /tunnels/{id}/endpoints/{endpointId}) + if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/endpoints/") { + endpoint := TunnelEndpoint{ + ID: "test-endpoint", + TunnelRelayTunnelEndpoint: TunnelRelayTunnelEndpoint{ + HostRelayURI: hostRelayURI, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(endpoint) + return + } + + // Handle DeleteTunnelEndpoints (DELETE /tunnels/{id}/endpoints/{endpointId}) + if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/endpoints/") { + w.WriteHeader(http.StatusOK) + return + } + + // Handle GetTunnel (GET /tunnels/{id}) + if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/tunnels/") && !strings.Contains(r.URL.Path, "/ports/") { + tunnel := Tunnel{ + Name: "test-tunnel", + Ports: []TunnelPort{ + {PortNumber: 8080}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tunnel) + return + } + + // Handle CreateTunnelPort (PUT /tunnels/{id}/ports/{portNumber}) + if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/ports/") { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TunnelPort{PortNumber: 8080}) + return + } + + // Handle DeleteTunnelPort (DELETE /tunnels/{id}/ports/{portNumber}) + if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/ports/") { + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusNotFound) + })) + + t.Cleanup(server.Close) + + serviceURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse mock server URL: %v", err) + } + + mgr := &Manager{ + tokenProvider: func() string { return "" }, + httpClient: &http.Client{}, + uri: serviceURL, + userAgents: []UserAgent{{Name: "test", Version: "1.0"}}, + apiVersion: "2023-09-27-preview", + } + + return server, mgr +} + +func newTestTunnel() *Tunnel { + return &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "test-token", + }, + } +} + +func TestHostConnectSuccessful(t *testing.T) { + // Start mock relay server. + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + // Create mock management API that returns the relay URL. + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer h.Close() + + // Verify the relay received the connection. + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Verify connection status is Connected. + if h.ConnectionStatus() != ConnectionStatusConnected { + t.Fatalf("expected ConnectionStatusConnected, got %v", h.ConnectionStatus()) + } +} + +func TestHostConnectRejectsInvalidToken(t *testing.T) { + // Relay expects a specific token. + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel correct-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + // Management API returns the relay URL. + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Tunnel has wrong token. + tunnel := &Tunnel{ + Name: "test-tunnel", + AccessTokens: map[TunnelAccessScope]string{ + TunnelAccessScopeHost: "wrong-token", + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err = h.Connect(ctx, tunnel) + if err == nil { + h.Close() + t.Fatal("expected error from Host.Connect with invalid token, got nil") + } +} + +func TestHostConnectHandlesRelayDisconnect(t *testing.T) { + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Close the relay server to simulate disconnect. + relayServer.Close() + + // Wait should return (not hang forever). + done := make(chan error, 1) + go func() { + done <- h.Wait() + }() + + select { + case <-done: + // Wait returned, which is what we expect. + case <-time.After(5 * time.Second): + t.Fatal("Host.Wait did not return after relay disconnect") + } +} + +func TestHostCloseDeletesEndpoint(t *testing.T) { + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + // Track whether DELETE endpoint was called. + deleteEndpointCalled := make(chan struct{}, 1) + mgmtServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/endpoints/") { + endpoint := TunnelEndpoint{ + ID: "test-endpoint", + TunnelRelayTunnelEndpoint: TunnelRelayTunnelEndpoint{ + HostRelayURI: relayServer.URL(), + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(endpoint) + return + } + if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/endpoints/") { + select { + case deleteEndpointCalled <- struct{}{}: + default: + } + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mgmtServer.Close() + + serviceURL, _ := url.Parse(mgmtServer.URL) + mgr := &Manager{ + tokenProvider: func() string { return "" }, + httpClient: &http.Client{}, + uri: serviceURL, + userAgents: []UserAgent{{Name: "test", Version: "1.0"}}, + apiVersion: "2023-09-27-preview", + } + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Close the host. + if err := h.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + // Verify DeleteTunnelEndpoints was called. + select { + case <-deleteEndpointCalled: + // Good, endpoint was deleted. + case <-time.After(5 * time.Second): + t.Fatal("Host.Close did not call DeleteTunnelEndpoints") + } +} + +func TestHostCloseIsIdempotent(t *testing.T) { + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // First close should succeed. + if err := h.Close(); err != nil { + t.Fatalf("first Host.Close failed: %v", err) + } + + // Second close should succeed (idempotent), not return ErrNotConnected. + if err := h.Close(); err != nil { + t.Fatalf("second Host.Close should be idempotent, got: %v", err) + } +} + +func TestHostConnectionStatusCallback(t *testing.T) { + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var mu sync.Mutex + var transitions []ConnectionStatus + h.ConnectionStatusChanged = func(prev, curr ConnectionStatus) { + mu.Lock() + transitions = append(transitions, curr) + mu.Unlock() + } + + // Verify initial status is None. + if h.ConnectionStatus() != ConnectionStatusNone { + t.Fatalf("expected ConnectionStatusNone, got %v", h.ConnectionStatus()) + } + + // Connect triggers Connecting -> Connected. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Close triggers Disconnected. + if err := h.Close(); err != nil { + t.Fatalf("Host.Close failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + expected := []ConnectionStatus{ + ConnectionStatusConnecting, + ConnectionStatusConnected, + ConnectionStatusDisconnected, + } + if len(transitions) != len(expected) { + t.Fatalf("expected %d transitions, got %d: %v", len(expected), len(transitions), transitions) + } + for i, s := range expected { + if transitions[i] != s { + t.Fatalf("transition[%d]: expected %v, got %v", i, s, transitions[i]) + } + } +} + +func TestHostTooManyConnectionsGuard(t *testing.T) { + mgr := newTestManager() + h, err := NewHost(nil, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Simulate a prior TooManyConnections disconnect. + h.disconnectReason = tunnelssh.SshDisconnectReasonTooManyConnections + + err = h.Connect(context.Background(), &Tunnel{}) + if !errors.Is(err, ErrTooManyConnections) { + t.Fatalf("expected ErrTooManyConnections, got %v", err) + } +} + +// tcpConnPair creates a pair of connected net.Conn via TCP loopback. +// Unlike net.Pipe(), TCP connections have kernel buffering so both sides +// can write concurrently without deadlocking during SSH handshakes. +func tcpConnPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + var serverErr error + done := make(chan struct{}) + go func() { + defer close(done) + serverConn, serverErr = ln.Accept() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + <-done + if serverErr != nil { + clientConn.Close() + t.Fatalf("failed to accept: %v", serverErr) + } + return clientConn, serverConn +} + +func TestHostAndClientIntegration(t *testing.T) { + // 1. Start a TCP echo server on a random localhost port. + echoListener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + defer echoListener.Close() + + echoPort := uint16(echoListener.Addr().(*net.TCPAddr).Port) + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + // 2. Create TCP loopback pair for host-relay connection. + hostEnd, relayEnd := tcpConnPair(t) + + // 3. Generate a host key. + _, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate host key: %v", err) + } + hostKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + + logger := log.New(os.Stderr, "e2e-test: ", log.LstdFlags) + + // 4. Create HostSSHSession (V2) and connect concurrently with the mock relay. + session := tunnelssh.NewHostSSHSession(hostEnd, hostKey, logger, "test-token", tunnelssh.HostWebSocketSubProtocolV2) + + var relay *tunnelstest.MockRelayForHost + var relayErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + relay, relayErr = tunnelstest.NewMockRelayForHost(relayEnd) + }() + + ctx := context.Background() + if err := session.Connect(ctx); err != nil { + t.Fatalf("failed to connect host session: %v", err) + } + wg.Wait() + if relayErr != nil { + t.Fatalf("relay SSH handshake failed: %v", relayErr) + } + defer func() { + session.Close() + relay.Close() + }() + + // 5. Add the echo port to the host — sends tcpip-forward to relay. + session.AddPort(echoPort, "test-token") + + // Give the relay time to process the tcpip-forward request. + time.Sleep(200 * time.Millisecond) + + // Verify relay registered the port. + if !relay.HasPort(echoPort) { + t.Fatalf("relay did not register port %d", echoPort) + } + + // 6. Simulate a client connection via the V2 mock relay. + clientConn, err := relay.SimulateClientConnection(echoPort) + if err != nil { + t.Fatalf("failed to simulate client connection: %v", err) + } + + // 7. In V2, the channel IS the data stream — no nested SSH. + // Send 'hello tunnel' through the tunnel and verify echo response. + testData := []byte("hello tunnel") + _, err = clientConn.Write(testData) + if err != nil { + t.Fatalf("failed to write through tunnel: %v", err) + } + + buf := make([]byte, len(testData)) + _, err = io.ReadFull(clientConn, buf) + if err != nil { + t.Fatalf("failed to read echo response: %v", err) + } + + if string(buf) != string(testData) { + t.Fatalf("data integrity check failed: sent %q, received %q", testData, buf) + } + + clientConn.Close() +} + +func TestHostRefreshPorts(t *testing.T) { + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + // Mock management API that returns different ports on GET vs what the host has locally. + remotePorts := []TunnelPort{ + {PortNumber: 3000}, + {PortNumber: 4000}, + } + mgmtServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // UpdateTunnelEndpoint + if r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/endpoints/") { + endpoint := TunnelEndpoint{ + ID: "test-endpoint", + TunnelRelayTunnelEndpoint: TunnelRelayTunnelEndpoint{ + HostRelayURI: relayServer.URL(), + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(endpoint) + return + } + // DeleteTunnelEndpoints + if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/endpoints/") { + w.WriteHeader(http.StatusOK) + return + } + // GetTunnel — returns the remote ports. + if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/tunnels/") { + tunnel := Tunnel{ + Name: "test-tunnel", + Ports: remotePorts, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tunnel) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mgmtServer.Close() + + serviceURL, _ := url.Parse(mgmtServer.URL) + mgr := &Manager{ + tokenProvider: func() string { return "" }, + httpClient: &http.Client{}, + uri: serviceURL, + userAgents: []UserAgent{{Name: "test", Version: "1.0"}}, + apiVersion: "2023-09-27-preview", + } + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer h.Close() + + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Locally add port 5000 via the SSH session (not on remote). + h.mu.Lock() + sshSession := h.ssh + h.mu.Unlock() + sshSession.AddPort(5000, "test-token") + + // Call RefreshPorts — should add 3000, 4000 and remove 5000. + if err := h.RefreshPorts(ctx); err != nil { + t.Fatalf("RefreshPorts failed: %v", err) + } + + if !sshSession.HasPort(3000) { + t.Fatal("expected port 3000 to be added by RefreshPorts") + } + if !sshSession.HasPort(4000) { + t.Fatal("expected port 4000 to be added by RefreshPorts") + } + if sshSession.HasPort(5000) { + t.Fatal("expected port 5000 to be removed by RefreshPorts") + } +} + +func TestHostConnectNegotiatesProtocol(t *testing.T) { + // Relay forces V1 protocol. + relayServer, err := tunnelstest.NewRelayHostServer( + tunnelstest.WithHostAccessToken("Tunnel test-token"), + tunnelstest.WithProtocolV1Only(), + ) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + _, mgr := newMockManagementAPI(t, relayServer.URL()) + + logger := log.New(os.Stderr, "test: ", log.LstdFlags) + h, err := NewHost(logger, mgr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tunnel := newTestTunnel() + if err := h.Connect(ctx, tunnel); err != nil { + t.Fatalf("Host.Connect failed: %v", err) + } + defer h.Close() + + if err := relayServer.WaitForConnection(5 * time.Second); err != nil { + t.Fatalf("relay did not receive connection: %v", err) + } + + // Verify the relay negotiated V1. + if relayServer.NegotiatedProtocol() != "tunnel-relay-host" { + t.Fatalf("expected V1 protocol, got %q", relayServer.NegotiatedProtocol()) + } + + // Verify the host session knows it's V1. + h.mu.Lock() + sshSession := h.ssh + h.mu.Unlock() + + if sshSession.ConnectionProtocol() != tunnelssh.HostWebSocketSubProtocol { + t.Fatalf("expected host session protocol %q, got %q", + tunnelssh.HostWebSocketSubProtocol, sshSession.ConnectionProtocol()) + } +} + +func TestHostAndClientIntegrationV1(t *testing.T) { + // 1. Start a TCP echo server on a random localhost port. + echoListener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + defer echoListener.Close() + + echoPort := uint16(echoListener.Addr().(*net.TCPAddr).Port) + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + // 2. Create TCP loopback pair for host-relay connection. + hostEnd, relayEnd := tcpConnPair(t) + + // 3. Generate a host key. + _, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate host key: %v", err) + } + hostKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + + logger := log.New(os.Stderr, "e2e-v1-test: ", log.LstdFlags) + + // 4. Create HostSSHSession (V1) and connect concurrently with the V1 mock relay. + session := tunnelssh.NewHostSSHSession(hostEnd, hostKey, logger, "", tunnelssh.HostWebSocketSubProtocol) + + var relay *tunnelstest.MockRelayForHostV1 + var relayErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + relay, relayErr = tunnelstest.NewMockRelayForHostV1(relayEnd) + }() + + ctx := context.Background() + if err := session.Connect(ctx); err != nil { + t.Fatalf("failed to connect host session: %v", err) + } + wg.Wait() + if relayErr != nil { + t.Fatalf("relay SSH handshake failed: %v", relayErr) + } + defer func() { + session.Close() + relay.Close() + }() + + // 5. Add the echo port to the host. + session.AddPort(echoPort, "") + + // 6. Simulate a V1 client connection via the mock relay. + client, err := relay.SimulateClientConnection() + if err != nil { + t.Fatalf("failed to simulate V1 client connection: %v", err) + } + defer client.Close() + + // Give time for port forward notifications. + time.Sleep(500 * time.Millisecond) + + // 7. Open a forwarded-tcpip channel from the client to the host. + // In V1, the client opens direct-tcpip to the host's nested SSH server. + ch, reqs, err := client.Conn.OpenChannel("direct-tcpip", ssh.Marshal(struct { + Host string + Port uint32 + OriginAddr string + OriginPort uint32 + }{ + Host: "127.0.0.1", + Port: uint32(echoPort), + OriginAddr: "127.0.0.1", + OriginPort: 0, + })) + if err != nil { + t.Fatalf("failed to open direct-tcpip channel: %v", err) + } + go ssh.DiscardRequests(reqs) + + // 8. Send data through the tunnel and verify echo. + testData := []byte("hello v1 tunnel") + _, err = ch.Write(testData) + if err != nil { + t.Fatalf("failed to write through V1 tunnel: %v", err) + } + + buf := make([]byte, len(testData)) + _, err = io.ReadFull(ch, buf) + if err != nil { + t.Fatalf("failed to read echo response: %v", err) + } + + if string(buf) != string(testData) { + t.Fatalf("V1 data integrity check failed: sent %q, received %q", testData, buf) + } + + ch.Close() +} diff --git a/go/tunnels/manager.go b/go/tunnels/manager.go index 9227fb76..7808ee1d 100644 --- a/go/tunnels/manager.go +++ b/go/tunnels/manager.go @@ -725,14 +725,12 @@ func (m *Manager) sendRequest( // Handle non 200s responses if result.StatusCode > 300 { - errorMessage, err := m.readProblemDetails(result) - if err == nil && errorMessage != nil { - return nil, fmt.Errorf("unsuccessful request, response: %d %s\n\t%s", - result.StatusCode, http.StatusText(result.StatusCode), *errorMessage) - } else { - return nil, fmt.Errorf("unsuccessful request, response: %d: %s", - result.StatusCode, http.StatusText(result.StatusCode)) + msg := fmt.Sprintf("%d %s", result.StatusCode, http.StatusText(result.StatusCode)) + errorMessage, pdErr := m.readProblemDetails(result) + if pdErr == nil && errorMessage != nil { + msg = fmt.Sprintf("%s\n\t%s", msg, *errorMessage) } + return nil, &TunnelError{StatusCode: result.StatusCode, Message: msg} } return io.ReadAll(result.Body) @@ -793,7 +791,7 @@ func (m *Manager) readProblemDetails(response *http.Response) (*string, error) { } func (m *Manager) getAccessToken(tunnel *Tunnel, tunnelRequestOptions *TunnelRequestOptions, accessTokenScopes []TunnelAccessScope) (token string) { - if tunnelRequestOptions.AccessToken != "" { + if tunnelRequestOptions != nil && tunnelRequestOptions.AccessToken != "" { token = fmt.Sprintf("%s %s", tunnelAuthenticationScheme, tunnelRequestOptions.AccessToken) } if token == "" { diff --git a/go/tunnels/socket.go b/go/tunnels/socket.go index 6d56c59c..4c90bd3d 100644 --- a/go/tunnels/socket.go +++ b/go/tunnels/socket.go @@ -81,6 +81,11 @@ func (s *socket) Write(b []byte) (int, error) { return bytesWritten, err } +// Subprotocol returns the negotiated WebSocket subprotocol. +func (s *socket) Subprotocol() string { + return s.conn.Subprotocol() +} + func (s *socket) Close() error { return s.conn.Close() } diff --git a/go/tunnels/ssh/client_ssh_stream.go b/go/tunnels/ssh/client_ssh_stream.go new file mode 100644 index 00000000..c840bd94 --- /dev/null +++ b/go/tunnels/ssh/client_ssh_stream.go @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelssh + +import ( + "net" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +// dummyAddr is a placeholder net.Addr used by channelConn. +type dummyAddr struct{} + +func (dummyAddr) Network() string { return "tunnel" } +func (dummyAddr) String() string { return "tunnel" } + +// channelConn wraps an ssh.Channel to implement the net.Conn interface. +// This is needed to pass an SSH channel to ssh.NewServerConn for +// per-client SSH sessions inside the relay tunnel. +type channelConn struct { + ssh.Channel + + mu sync.Mutex + timer *time.Timer +} + +// LocalAddr returns a dummy address. +func (c *channelConn) LocalAddr() net.Addr { + return dummyAddr{} +} + +// RemoteAddr returns a dummy address. +func (c *channelConn) RemoteAddr() net.Addr { + return dummyAddr{} +} + +// SetDeadline sets a deadline that closes the channel on expiration. +// This is required for SSH handshake timeouts. A zero-value time clears +// the timer without closing the channel. +func (c *channelConn) SetDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + + // Stop any existing timer to prevent leaks. + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + + // A zero-value deadline clears the timer. + if t.IsZero() { + return nil + } + + d := time.Until(t) + if d <= 0 { + // Deadline already passed, close immediately. + c.Channel.Close() + return nil + } + + c.timer = time.AfterFunc(d, func() { + c.Channel.Close() + }) + return nil +} + +// SetReadDeadline delegates to SetDeadline. +func (c *channelConn) SetReadDeadline(t time.Time) error { + return c.SetDeadline(t) +} + +// SetWriteDeadline is a no-op returning nil. +func (c *channelConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/go/tunnels/ssh/client_ssh_stream_test.go b/go/tunnels/ssh/client_ssh_stream_test.go new file mode 100644 index 00000000..dda56520 --- /dev/null +++ b/go/tunnels/ssh/client_ssh_stream_test.go @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelssh + +import ( + "io" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +// mockSSHChannel is a mock ssh.Channel for testing channelConn. +type mockSSHChannel struct { + readBuf []byte + writeBuf []byte + closed bool + mu sync.Mutex + readCh chan []byte + closeCh chan struct{} +} + +func newMockChannel() *mockSSHChannel { + return &mockSSHChannel{ + readCh: make(chan []byte, 10), + closeCh: make(chan struct{}), + } +} + +func (m *mockSSHChannel) Read(data []byte) (int, error) { + m.mu.Lock() + if m.closed { + m.mu.Unlock() + return 0, io.EOF + } + m.mu.Unlock() + + select { + case b := <-m.readCh: + n := copy(data, b) + return n, nil + case <-m.closeCh: + return 0, io.EOF + } +} + +func (m *mockSSHChannel) Write(data []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.closed { + return 0, io.ErrClosedPipe + } + m.writeBuf = append(m.writeBuf, data...) + return len(data), nil +} + +func (m *mockSSHChannel) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + if !m.closed { + m.closed = true + close(m.closeCh) + } + return nil +} + +func (m *mockSSHChannel) CloseWrite() error { return nil } + +func (m *mockSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return false, nil +} + +func (m *mockSSHChannel) Stderr() io.ReadWriter { + return nil +} + +func (m *mockSSHChannel) isClosed() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.closed +} + +// Ensure mockSSHChannel implements ssh.Channel. +var _ ssh.Channel = (*mockSSHChannel)(nil) + +func TestChannelConnLocalAddr(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + addr := conn.LocalAddr() + if addr.Network() != "tunnel" { + t.Fatalf("expected Network()=='tunnel', got %q", addr.Network()) + } + if addr.String() != "tunnel" { + t.Fatalf("expected String()=='tunnel', got %q", addr.String()) + } +} + +func TestChannelConnRemoteAddr(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + addr := conn.RemoteAddr() + if addr.Network() != "tunnel" { + t.Fatalf("expected Network()=='tunnel', got %q", addr.Network()) + } + if addr.String() != "tunnel" { + t.Fatalf("expected String()=='tunnel', got %q", addr.String()) + } +} + +func TestChannelConnSetDeadlineClosesChannel(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + // Set a short deadline. + err := conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Wait for the deadline to expire. + time.Sleep(100 * time.Millisecond) + + if !ch.isClosed() { + t.Fatal("expected channel to be closed after deadline expiration") + } +} + +func TestChannelConnSetDeadlineClearsTimer(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + // Set a deadline. + err := conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Clear the deadline with zero time. + err = conn.SetDeadline(time.Time{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Wait past the original deadline. + time.Sleep(100 * time.Millisecond) + + if ch.isClosed() { + t.Fatal("expected channel to stay open after clearing deadline") + } +} + +func TestChannelConnSetDeadlineResetsTimer(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + // Set a short deadline. + err := conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Reset to a longer deadline before the first expires. + err = conn.SetDeadline(time.Now().Add(500 * time.Millisecond)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Wait past the first deadline but before the second. + time.Sleep(100 * time.Millisecond) + + if ch.isClosed() { + t.Fatal("expected channel to stay open after resetting to longer deadline") + } + + // Clean up. + conn.SetDeadline(time.Time{}) +} + +func TestChannelConnReadDelegatesToChannel(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + expected := []byte("hello from channel") + ch.readCh <- expected + + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(buf[:n]) != string(expected) { + t.Fatalf("expected %q, got %q", expected, buf[:n]) + } +} + +func TestChannelConnWriteDelegatesToChannel(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + data := []byte("hello to channel") + n, err := conn.Write(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != len(data) { + t.Fatalf("expected %d bytes written, got %d", len(data), n) + } + + ch.mu.Lock() + defer ch.mu.Unlock() + if string(ch.writeBuf) != string(data) { + t.Fatalf("expected %q in write buffer, got %q", data, ch.writeBuf) + } +} + +func TestChannelConnSetWriteDeadlineIsNoop(t *testing.T) { + ch := newMockChannel() + conn := &channelConn{Channel: ch} + + err := conn.SetWriteDeadline(time.Now().Add(time.Second)) + if err != nil { + t.Fatalf("expected nil error from SetWriteDeadline, got %v", err) + } +} diff --git a/go/tunnels/ssh/host_port_forward_test.go b/go/tunnels/ssh/host_port_forward_test.go new file mode 100644 index 00000000..6c9d54d0 --- /dev/null +++ b/go/tunnels/ssh/host_port_forward_test.go @@ -0,0 +1,477 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelssh + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +// startEchoServer starts a TCP echo server on a random port and returns +// the listener and port number. The server echoes all received data back. +func startEchoServer(t *testing.T, network string, addr string) (net.Listener, uint16) { + t.Helper() + listener, err := net.Listen(network, addr) + if err != nil { + t.Fatalf("failed to start echo server: %v", err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + port := uint16(listener.Addr().(*net.TCPAddr).Port) + return listener, port +} + +func TestPortForwardingEchoServer(t *testing.T) { + // Start echo server. + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + // Setup host session with the port registered. + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + // Open a forwarded-tcpip channel from the relay. + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open forwarded-tcpip channel: %v", err) + } + defer ch.Close() + + // Send data through the tunnel. + testData := []byte("hello tunnel") + _, err = ch.Write(testData) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + // Read the echo response. + buf := make([]byte, len(testData)) + _, err = io.ReadFull(ch, buf) + if err != nil { + t.Fatalf("failed to read echo: %v", err) + } + + if string(buf) != string(testData) { + t.Fatalf("expected %q, got %q", testData, buf) + } +} + +func TestPortForwardingIPv4(t *testing.T) { + // Start echo server on IPv4. + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open channel: %v", err) + } + defer ch.Close() + + _, err = ch.Write([]byte("ipv4 test")) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + buf := make([]byte, 9) + _, err = io.ReadFull(ch, buf) + if err != nil { + t.Fatalf("failed to read: %v", err) + } + if string(buf) != "ipv4 test" { + t.Fatalf("expected 'ipv4 test', got %q", buf) + } +} + +func TestPortForwardingConnectionRefused(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Add a port that has no listener (connection will be refused). + session.AddPort(19999, "test-token") + relay.waitForPortForward(t, 19999) + + ch, err := relay.openForwardedTCPIP(19999) + if err != nil { + // Channel may be rejected, which is acceptable. + return + } + + // The channel should be closed by the host after connection refused. + buf := make([]byte, 1) + _, err = ch.Read(buf) + if err == nil { + t.Fatal("expected error reading from channel with no backend") + } +} + +func TestPortForwardingBidirectional(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open channel: %v", err) + } + defer ch.Close() + + // Send multiple messages and verify echo. + for i := 0; i < 5; i++ { + msg := fmt.Sprintf("message %d", i) + _, err = ch.Write([]byte(msg)) + if err != nil { + t.Fatalf("write %d failed: %v", i, err) + } + + buf := make([]byte, len(msg)) + _, err = io.ReadFull(ch, buf) + if err != nil { + t.Fatalf("read %d failed: %v", i, err) + } + if string(buf) != msg { + t.Fatalf("message %d: expected %q, got %q", i, msg, buf) + } + } +} + +func TestDirectTcpipRejectsUnregisteredPort(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + // Register only echoPort — the unregistered port should be rejected. + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + // Try to open a direct-tcpip channel to an unregistered port. + unregisteredPort := echoPort + 1000 + _, err := relay.openDirectTCPIP(unregisteredPort) + if err == nil { + t.Fatal("expected direct-tcpip to unregistered port to be rejected, got nil error") + } + + // Verify it's a rejection with Prohibited reason. + if openErr, ok := err.(*ssh.OpenChannelError); ok { + if openErr.Reason != ssh.Prohibited { + t.Fatalf("expected Prohibited rejection reason, got %v", openErr.Reason) + } + } else { + t.Fatalf("expected *ssh.OpenChannelError, got %T: %v", err, err) + } +} + +func TestDirectTcpipChannel(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openDirectTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open direct-tcpip: %v", err) + } + defer ch.Close() + + msg := []byte("direct-tcpip-test") + if _, err := ch.Write(msg); err != nil { + t.Fatalf("failed to write: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(ch, buf); err != nil { + t.Fatalf("failed to read: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("expected %q, got %q", msg, buf) + } +} + +func TestPortForwardingIPv6Fallback(t *testing.T) { + // Start echo server on IPv6 only. + echoListener, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available") + } + echoPort := uint16(echoListener.Addr().(*net.TCPAddr).Port) + + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open channel: %v", err) + } + defer ch.Close() + + testData := []byte("ipv6 test") + _, err = ch.Write(testData) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + buf := make([]byte, len(testData)) + _, err = io.ReadFull(ch, buf) + if err != nil { + t.Fatalf("failed to read: %v", err) + } + if string(buf) != string(testData) { + t.Fatalf("expected %q, got %q", testData, buf) + } +} + +func TestPortForwardingLargePayload(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open channel: %v", err) + } + + // Send 1MB of data. + payload := make([]byte, 1024*1024) + for i := range payload { + payload[i] = byte(i % 256) + } + expectedHash := sha256.Sum256(payload) + + var wg sync.WaitGroup + var writeErr error + + wg.Add(1) + go func() { + defer wg.Done() + _, writeErr = ch.Write(payload) + ch.CloseWrite() + }() + + // Read all echoed data. + received := new(bytes.Buffer) + _, err = io.Copy(received, ch) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("failed to read echo: %v", err) + } + + wg.Wait() + if writeErr != nil { + t.Fatalf("write error: %v", writeErr) + } + + actualHash := sha256.Sum256(received.Bytes()) + if actualHash != expectedHash { + t.Fatalf("SHA256 mismatch: payload integrity check failed (sent %d bytes, received %d bytes)", + len(payload), received.Len()) + } +} + +func TestPortForwardingConcurrentConnections(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + var wg sync.WaitGroup + errc := make(chan error, 5) + + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + errc <- fmt.Errorf("connection %d: open channel failed: %w", idx, err) + return + } + defer ch.Close() + + msg := fmt.Sprintf("concurrent-%d", idx) + _, err = ch.Write([]byte(msg)) + if err != nil { + errc <- fmt.Errorf("connection %d: write failed: %w", idx, err) + return + } + + buf := make([]byte, len(msg)) + _, err = io.ReadFull(ch, buf) + if err != nil { + errc <- fmt.Errorf("connection %d: read failed: %w", idx, err) + return + } + if string(buf) != msg { + errc <- fmt.Errorf("connection %d: expected %q, got %q", idx, msg, buf) + } + }(i) + } + + wg.Wait() + close(errc) + for err := range errc { + t.Fatal(err) + } +} + +func TestAddPortNotifiesRelay(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Start an echo server. + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + // Add a port — relay should receive tcpip-forward. + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + // Verify relay knows about the port. + if !relay.hasPort(echoPort) { + t.Fatal("relay should have the port registered") + } +} + +func TestRemovePortNotifiesRelay(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Add a port first. + session.AddPort(9090, "test-token") + relay.waitForPortForward(t, 9090) + + // Now remove the port — relay should receive cancel-tcpip-forward. + session.RemovePort(9090, "test-token") + + // Wait for cancel request. + for { + select { + case info := <-relay.portReqs: + if info.reqType == "cancel-tcpip-forward" && info.port == 9090 { + // Verify the port was removed from the session's port list. + session.portsMu.RLock() + for _, p := range session.ports { + if p == 9090 { + session.portsMu.RUnlock() + t.Fatal("port 9090 should have been removed") + } + } + session.portsMu.RUnlock() + return + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for cancel-tcpip-forward notification") + return + } + } +} + +func TestForwardedTcpipChannel(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open forwarded-tcpip: %v", err) + } + defer ch.Close() + + msg := []byte("forwarded-tcpip-test") + if _, err := ch.Write(msg); err != nil { + t.Fatalf("failed to write: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(ch, buf); err != nil { + t.Fatalf("failed to read: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("expected %q, got %q", msg, buf) + } +} + +func TestClientDisconnectMidTransfer(t *testing.T) { + echoListener, echoPort := startEchoServer(t, "tcp4", "127.0.0.1:0") + defer echoListener.Close() + + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + session.AddPort(echoPort, "test-token") + relay.waitForPortForward(t, echoPort) + + // Open a forwarded-tcpip channel. + ch, err := relay.openForwardedTCPIP(echoPort) + if err != nil { + t.Fatalf("failed to open channel: %v", err) + } + + // Write some data. + ch.Write([]byte("partial data")) + + // Abruptly close the channel mid-transfer. + ch.Close() + + // Wait a bit to ensure no panic or goroutine leak crashes. + time.Sleep(200 * time.Millisecond) +} diff --git a/go/tunnels/ssh/host_session.go b/go/tunnels/ssh/host_session.go new file mode 100644 index 00000000..9a173547 --- /dev/null +++ b/go/tunnels/ssh/host_session.go @@ -0,0 +1,573 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelssh + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "regexp" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages" + "golang.org/x/crypto/ssh" +) + +// forwardedTCPIPData is the channel extra data for forwarded-tcpip and direct-tcpip channels. +type forwardedTCPIPData struct { + Host string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +const ( + sshUser = "tunnel" + localHostIPv4 = "127.0.0.1" + localHostIPv6 = "[::1]" + + sshHandshakeTimeout = 10 * time.Second + tcpDialTimeout = 5 * time.Second + + // SshDisconnectReasonTooManyConnections is the SSH disconnect reason code + // sent by the relay when too many host connections are active. + SshDisconnectReasonTooManyConnections uint32 = 11 + + // HostWebSocketSubProtocol is the V1 WebSocket subprotocol for host relay connections. + HostWebSocketSubProtocol = "tunnel-relay-host" + + // HostWebSocketSubProtocolV2 is the V2 WebSocket subprotocol for host relay connections. + HostWebSocketSubProtocolV2 = "tunnel-relay-host-v2-dev" + + // ClientStreamChannelType is the SSH channel type used in V1 for client session streams. + ClientStreamChannelType = "client-ssh-session-stream" +) + +// disconnectReasonRe matches the x/crypto/ssh disconnect message format. +var disconnectReasonRe = regexp.MustCompile(`ssh: disconnect, reason (\d+):`) + +// HostSSHSession manages the SSH connection from the host to the relay. +// +// In V2 (tunnel-relay-host-v2-dev), the host sends tcpip-forward requests +// directly to the relay (with an access token), and the relay opens +// forwarded-tcpip channels directly to the host — no nested SSH sessions. +// +// In V1 (tunnel-relay-host), the relay opens client-ssh-session-stream +// channels. Each channel carries a nested SSH connection from a client. +// The host performs a nested SSH server handshake on each channel, then +// forwards ports to each client individually via tcpip-forward requests +// on the nested connection. +// +// Locking strategy: connMu, portsMu, and clientsMu guard independent +// state and are never held simultaneously. +type HostSSHSession struct { + conn net.Conn + hostKey ssh.Signer + logger *log.Logger + accessToken string + connectionProtocol string + + // connMu guards sshConn and disconnectReason. + connMu sync.Mutex + sshConn ssh.Conn + disconnectReason uint32 + + portsMu sync.RWMutex + ports []uint16 + + // V1 per-client SSH connections. + clientsMu sync.RWMutex + clients map[string]*ssh.ServerConn + clientCounter uint32 // atomic +} + +// NewHostSSHSession creates a new HostSSHSession. +func NewHostSSHSession(conn net.Conn, hostKey ssh.Signer, logger *log.Logger, accessToken string, connectionProtocol string) *HostSSHSession { + return &HostSSHSession{ + conn: conn, + hostKey: hostKey, + logger: logger, + accessToken: accessToken, + connectionProtocol: connectionProtocol, + clients: make(map[string]*ssh.ServerConn), + } +} + +// ConnectionProtocol returns the negotiated WebSocket subprotocol. +func (s *HostSSHSession) ConnectionProtocol() string { + return s.connectionProtocol +} + +// Connect establishes the SSH client connection to the relay. +func (s *HostSSHSession) Connect(ctx context.Context) error { + clientConfig := ssh.ClientConfig{ + User: sshUser, + Timeout: sshHandshakeTimeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + sshConn, chans, reqs, err := ssh.NewClientConn(s.conn, "", &clientConfig) + if err != nil { + return fmt.Errorf("error creating SSH client connection to relay: %w", err) + } + + s.connMu.Lock() + s.sshConn = sshConn + s.connMu.Unlock() + + go s.handleGlobalRequests(reqs) + go s.handleIncomingChannels(ctx, chans) + + return nil +} + +// Wait blocks until the relay SSH connection drops. +func (s *HostSSHSession) Wait() error { + s.connMu.Lock() + conn := s.sshConn + s.connMu.Unlock() + + if conn == nil { + return fmt.Errorf("not connected") + } + + err := conn.Wait() + + s.connMu.Lock() + s.disconnectReason = parseDisconnectReason(err) + s.connMu.Unlock() + + return err +} + +// Close closes the SSH connection to the relay. +func (s *HostSSHSession) Close() error { + s.connMu.Lock() + conn := s.sshConn + s.connMu.Unlock() + + // Close all V1 client connections. + s.clientsMu.Lock() + for id, client := range s.clients { + client.Close() + delete(s.clients, id) + } + s.clientsMu.Unlock() + + if conn == nil { + return nil + } + return conn.Close() +} + +// DisconnectReason returns the SSH disconnect reason code from the last +// disconnection, or 0 if not disconnected or reason unknown. +func (s *HostSSHSession) DisconnectReason() uint32 { + s.connMu.Lock() + defer s.connMu.Unlock() + return s.disconnectReason +} + +// parseDisconnectReason extracts the SSH disconnect reason code from an error +// returned by x/crypto/ssh. Returns 0 if the reason cannot be determined. +func parseDisconnectReason(err error) uint32 { + if err == nil { + return 0 + } + m := disconnectReasonRe.FindStringSubmatch(err.Error()) + if len(m) < 2 { + return 0 + } + reason, convErr := strconv.ParseUint(m[1], 10, 32) + if convErr != nil { + return 0 + } + return uint32(reason) +} + +// handleGlobalRequests rejects unknown global requests from the relay. +func (s *HostSSHSession) handleGlobalRequests(reqs <-chan *ssh.Request) { + for r := range reqs { + r.Reply(false, nil) + } +} + +// handleIncomingChannels dispatches incoming channels from the relay +// based on channel type and connection protocol. +func (s *HostSSHSession) handleIncomingChannels(ctx context.Context, chans <-chan ssh.NewChannel) { + for newChan := range chans { + switch newChan.ChannelType() { + case ClientStreamChannelType: + if s.connectionProtocol == HostWebSocketSubProtocol { + go s.handleClientSession(ctx, newChan) + } else { + newChan.Reject(ssh.UnknownChannelType, "unknown channel type") + } + case "forwarded-tcpip", "direct-tcpip": + go s.handleV2Channel(ctx, newChan) + default: + newChan.Reject(ssh.UnknownChannelType, "unknown channel type") + } + } +} + +// handleClientSession handles a V1 client-ssh-session-stream channel. +// It accepts the channel, wraps it as a net.Conn, performs a nested SSH +// server handshake, forwards existing ports, and handles client channels. +func (s *HostSSHSession) handleClientSession(ctx context.Context, newChan ssh.NewChannel) { + channel, reqs, err := newChan.Accept() + if err != nil { + s.logger.Printf("failed to accept client-ssh-session-stream: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + // Wrap the channel as a net.Conn for the nested SSH handshake. + conn := &channelConn{Channel: channel} + + // Nested SSH server config — NoClientAuth since the relay already authenticated. + serverConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(s.hostKey) + + serverConn, clientChans, clientReqs, err := ssh.NewServerConn(conn, serverConfig) + if err != nil { + s.logger.Printf("nested SSH handshake failed: %v", err) + channel.Close() + return + } + + // Generate a unique client ID and store the connection. + clientID := fmt.Sprintf("client-%d", atomic.AddUint32(&s.clientCounter, 1)) + s.clientsMu.Lock() + s.clients[clientID] = serverConn + s.clientsMu.Unlock() + + s.logger.Printf("V1 client connected: %s", clientID) + + // Forward existing ports to this client. + s.portsMu.RLock() + currentPorts := make([]uint16, len(s.ports)) + copy(currentPorts, s.ports) + s.portsMu.RUnlock() + + for _, port := range currentPorts { + s.forwardPortToClient(serverConn, port) + } + + // Handle global requests from the client (tcpip-forward, cancel-tcpip-forward). + go s.handleClientGlobalRequests(clientReqs) + + // Handle channels from the client. + go s.handleClientChannels(ctx, serverConn, clientChans) + + // Wait for the client connection to close, then clean up. + serverConn.Wait() + + s.clientsMu.Lock() + delete(s.clients, clientID) + s.clientsMu.Unlock() + + s.logger.Printf("V1 client disconnected: %s", clientID) +} + +// handleClientGlobalRequests handles global requests from a V1 client. +// In V1, clients may send tcpip-forward requests which we accept. +func (s *HostSSHSession) handleClientGlobalRequests(reqs <-chan *ssh.Request) { + for req := range reqs { + switch req.Type { + case "tcpip-forward": + // Accept port forward requests from clients. + req.Reply(true, nil) + case "cancel-tcpip-forward": + req.Reply(true, nil) + default: + req.Reply(false, nil) + } + } +} + +// forwardPortToClient sends a tcpip-forward request to a V1 client's +// nested SSH connection, notifying the client that a port is available. +func (s *HostSSHSession) forwardPortToClient(serverConn *ssh.ServerConn, port uint16) { + pfr := messages.NewPortForwardRequest(localHostIPv4, uint32(port)) + b, err := pfr.Marshal() + if err != nil { + s.logger.Printf("error marshaling port forward request for port %d: %v", port, err) + return + } + + _, _, err = serverConn.SendRequest(messages.PortForwardRequestType, true, b) + if err != nil { + s.logger.Printf("error sending tcpip-forward to client for port %d: %v", port, err) + } +} + +// cancelForwardPortToClient sends a cancel-tcpip-forward request to a V1 +// client's nested SSH connection. +func (s *HostSSHSession) cancelForwardPortToClient(serverConn *ssh.ServerConn, port uint16) { + pfcr := messages.NewPortForwardCancelRequest(localHostIPv4, uint32(port)) + b, err := pfcr.Marshal() + if err != nil { + s.logger.Printf("error marshaling cancel port forward request for port %d: %v", port, err) + return + } + + _, _, err = serverConn.SendRequest(messages.PortForwardCancelRequestType, true, b) + if err != nil { + s.logger.Printf("error sending cancel-tcpip-forward to client for port %d: %v", port, err) + } +} + +// handleClientChannels dispatches channels from a V1 client's nested SSH connection. +func (s *HostSSHSession) handleClientChannels(ctx context.Context, serverConn *ssh.ServerConn, chans <-chan ssh.NewChannel) { + for newChan := range chans { + switch newChan.ChannelType() { + case "forwarded-tcpip", "direct-tcpip": + // Parse the standard forwarded-tcpip extra data to get the port. + var data forwardedTCPIPData + if err := ssh.Unmarshal(newChan.ExtraData(), &data); err != nil { + newChan.Reject(ssh.ConnectionFailed, "invalid channel data") + continue + } + + if !s.HasPort(uint16(data.Port)) { + s.logger.Printf("V1 client: rejected %s to unregistered port %d", newChan.ChannelType(), data.Port) + newChan.Reject(ssh.Prohibited, "port not registered") + continue + } + + go s.handleForwardedTCPIP(ctx, uint16(data.Port), newChan) + default: + newChan.Reject(ssh.UnknownChannelType, "unknown channel type") + } + } +} + +// handleV2Channel handles a forwarded-tcpip or direct-tcpip channel from +// the V2 relay. It parses the V2 extra data (with access token and E2E +// encryption flag), validates the port, and proxies to local TCP. +func (s *HostSSHSession) handleV2Channel(ctx context.Context, newChan ssh.NewChannel) { + // Parse channel extra data. Try V2 format first (has additional fields), + // fall back to standard forwarded-tcpip format. + extraData := newChan.ExtraData() + var port uint32 + + var v2Data messages.PortRelayConnectRequest + if err := v2Data.Unmarshal(bytes.NewReader(extraData)); err == nil { + port = v2Data.Port + } else { + // Fall back to standard format. + var data forwardedTCPIPData + if err := ssh.Unmarshal(extraData, &data); err != nil { + newChan.Reject(ssh.ConnectionFailed, "invalid channel data") + return + } + port = data.Port + } + + if !s.HasPort(uint16(port)) { + s.logger.Printf("rejected %s to unregistered port %d", newChan.ChannelType(), port) + newChan.Reject(ssh.Prohibited, "port not registered") + return + } + + s.handleForwardedTCPIP(ctx, uint16(port), newChan) +} + +// HasPort reports whether the given port is in the forwarded ports list. +func (s *HostSSHSession) HasPort(port uint16) bool { + s.portsMu.RLock() + defer s.portsMu.RUnlock() + for _, p := range s.ports { + if p == port { + return true + } + } + return false +} + +// Ports returns a copy of the currently registered port list. +func (s *HostSSHSession) Ports() []uint16 { + s.portsMu.RLock() + defer s.portsMu.RUnlock() + result := make([]uint16, len(s.ports)) + copy(result, s.ports) + return result +} + +// handleForwardedTCPIP proxies a forwarded-tcpip or direct-tcpip channel +// to a local TCP port. +func (s *HostSSHSession) handleForwardedTCPIP(ctx context.Context, port uint16, newChan ssh.NewChannel) { + channel, reqs, err := newChan.Accept() + if err != nil { + s.logger.Printf("failed to accept forwarded-tcpip channel for port %d: %v", port, err) + return + } + go ssh.DiscardRequests(reqs) + + // Dial local TCP: try IPv4 first, fall back to IPv6. + var tcpConn net.Conn + dialer := net.Dialer{Timeout: tcpDialTimeout} + tcpConn, err = dialer.DialContext(ctx, "tcp4", fmt.Sprintf("%s:%d", localHostIPv4, port)) + if err != nil { + tcpConn, err = dialer.DialContext(ctx, "tcp6", fmt.Sprintf("%s:%d", localHostIPv6, port)) + if err != nil { + s.logger.Printf("failed to dial local port %d: %v", port, err) + channel.Close() + return + } + } + + // Bidirectional io.Copy between the SSH channel and the TCP connection. + // When one direction finishes, close the write side of the destination + // to propagate EOF, then wait for the reverse direction to drain. + errs := make(chan error, 2) + + go func() { + _, err := io.Copy(tcpConn, channel) + // Signal EOF to the local service so it stops reading. + if tc, ok := tcpConn.(*net.TCPConn); ok { + tc.CloseWrite() + } + errs <- err + }() + + go func() { + _, err := io.Copy(channel, tcpConn) + channel.CloseWrite() + errs <- err + }() + + // Wait for both directions to complete or context cancellation. + select { + case <-ctx.Done(): + case <-errs: + // One direction done — wait for the other to drain. + select { + case <-ctx.Done(): + case <-errs: + } + } + + channel.Close() + tcpConn.Close() +} + +// AddPort adds a port to the port list and notifies the relay or connected clients. +// If the port is already registered, this is a no-op. +// +// In V2, sends a tcpip-forward global request to the relay with the access token. +// In V1, iterates connected clients and sends tcpip-forward to each. +func (s *HostSSHSession) AddPort(port uint16, accessToken string) { + s.portsMu.Lock() + for _, p := range s.ports { + if p == port { + s.portsMu.Unlock() + return + } + } + s.ports = append(s.ports, port) + s.portsMu.Unlock() + + if s.connectionProtocol == HostWebSocketSubProtocol { + // V1: notify each connected client. + s.clientsMu.RLock() + clients := make([]*ssh.ServerConn, 0, len(s.clients)) + for _, c := range s.clients { + clients = append(clients, c) + } + s.clientsMu.RUnlock() + + for _, client := range clients { + s.forwardPortToClient(client, port) + } + return + } + + // V2: send tcpip-forward to the relay. + s.connMu.Lock() + conn := s.sshConn + s.connMu.Unlock() + + if conn == nil { + return + } + + prr := messages.NewPortRelayRequest(localHostIPv4, uint32(port), accessToken) + b, err := prr.Marshal() + if err != nil { + s.logger.Printf("error marshaling port relay request for port %d: %v", port, err) + return + } + + _, _, err = conn.SendRequest(messages.PortForwardRequestType, true, b) + if err != nil { + s.logger.Printf("error sending tcpip-forward for port %d: %v", port, err) + } +} + +// RemovePort removes a port from the port list and notifies the relay or connected clients. +// +// In V2, sends a cancel-tcpip-forward global request to the relay. +// In V1, iterates connected clients and sends cancel-tcpip-forward to each. +func (s *HostSSHSession) RemovePort(port uint16, accessToken string) { + s.portsMu.Lock() + found := false + for i, p := range s.ports { + if p == port { + s.ports = append(s.ports[:i], s.ports[i+1:]...) + found = true + break + } + } + s.portsMu.Unlock() + + if !found { + return + } + + if s.connectionProtocol == HostWebSocketSubProtocol { + // V1: notify each connected client. + s.clientsMu.RLock() + clients := make([]*ssh.ServerConn, 0, len(s.clients)) + for _, c := range s.clients { + clients = append(clients, c) + } + s.clientsMu.RUnlock() + + for _, client := range clients { + s.cancelForwardPortToClient(client, port) + } + return + } + + // V2: send cancel-tcpip-forward to the relay. + s.connMu.Lock() + conn := s.sshConn + s.connMu.Unlock() + + if conn == nil { + return + } + + prr := messages.NewPortRelayRequest(localHostIPv4, uint32(port), accessToken) + b, err := prr.Marshal() + if err != nil { + s.logger.Printf("error marshaling cancel port relay request for port %d: %v", port, err) + return + } + + _, _, err = conn.SendRequest(messages.PortForwardCancelRequestType, true, b) + if err != nil { + s.logger.Printf("error sending cancel-tcpip-forward for port %d: %v", port, err) + } +} diff --git a/go/tunnels/ssh/host_session_test.go b/go/tunnels/ssh/host_session_test.go new file mode 100644 index 00000000..2f85a116 --- /dev/null +++ b/go/tunnels/ssh/host_session_test.go @@ -0,0 +1,769 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelssh + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/rand" + "log" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages" + "golang.org/x/crypto/ssh" +) + +const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV +rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY +lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB +AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0 +5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8 +IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC +46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n +IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC +t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz +J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA +hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV +933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP +FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw== +-----END RSA PRIVATE KEY-----` + +func newTestHostKey() ssh.Signer { + _, privKey, _ := ed25519.GenerateKey(rand.Reader) + signer, _ := ssh.NewSignerFromKey(privKey) + return signer +} + +func newTestLogger() *log.Logger { + return log.New(os.Stderr, "test: ", log.LstdFlags) +} + +// tcpConnPair creates a pair of connected net.Conn via TCP loopback. +// Unlike net.Pipe(), TCP connections have kernel buffering so both sides +// can write concurrently without deadlocking during SSH handshakes. +func tcpConnPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + var serverErr error + done := make(chan struct{}) + go func() { + defer close(done) + serverConn, serverErr = ln.Accept() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + <-done + if serverErr != nil { + clientConn.Close() + t.Fatalf("failed to accept: %v", serverErr) + } + return clientConn, serverConn +} + +// mockV2Relay holds the relay side of an SSH connection that handles +// tcpip-forward/cancel-tcpip-forward global requests like a V2 relay. +type mockV2Relay struct { + conn ssh.Conn + mu sync.Mutex + ports map[uint16]struct{} + portReqs chan portReqInfo +} + +type portReqInfo struct { + reqType string + port uint16 +} + +// setupHostAndRelay creates a HostSSHSession connected via TCP loopback +// with a V2 mock relay that handles tcpip-forward requests. +// Returns the host session, the mock relay, and a cleanup function. +func setupHostAndRelay(t *testing.T) (*HostSSHSession, *mockV2Relay, func()) { + t.Helper() + + hostEnd, relayEnd := tcpConnPair(t) + hostKey := newTestHostKey() + logger := newTestLogger() + + session := NewHostSSHSession(hostEnd, hostKey, logger, "test-token", HostWebSocketSubProtocolV2) + + // Relay side: SSH server + relayConfig := &ssh.ServerConfig{NoClientAuth: true} + privateKey, err := ssh.ParsePrivateKey([]byte(testRSAPrivateKey)) + if err != nil { + t.Fatalf("failed to parse private key: %v", err) + } + relayConfig.AddHostKey(privateKey) + + relay := &mockV2Relay{ + ports: make(map[uint16]struct{}), + portReqs: make(chan portReqInfo, 100), + } + + // Connect concurrently: host connects as SSH client, relay accepts as SSH server. + var relayConn ssh.Conn + var relayErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var reqs <-chan *ssh.Request + var chans <-chan ssh.NewChannel + relayConn, chans, reqs, relayErr = ssh.NewServerConn(relayEnd, relayConfig) + if relayErr == nil { + go relay.handleGlobalRequests(reqs) + // Reject any channels from the host (V2 relay doesn't expect host-initiated channels). + go func() { + for ch := range chans { + ch.Reject(ssh.Prohibited, "not supported") + } + }() + } + }() + + ctx := context.Background() + if err := session.Connect(ctx); err != nil { + t.Fatalf("failed to connect host session: %v", err) + } + + wg.Wait() + if relayErr != nil { + t.Fatalf("relay SSH handshake failed: %v", relayErr) + } + + relay.conn = relayConn + + cleanup := func() { + session.Close() + relayConn.Close() + } + + return session, relay, cleanup +} + +func (r *mockV2Relay) handleGlobalRequests(reqs <-chan *ssh.Request) { + for req := range reqs { + switch req.Type { + case "tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + port := uint16(prr.Port()) + r.mu.Lock() + r.ports[port] = struct{}{} + r.mu.Unlock() + r.portReqs <- portReqInfo{reqType: "tcpip-forward", port: port} + req.Reply(true, nil) + + case "cancel-tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + port := uint16(prr.Port()) + r.mu.Lock() + delete(r.ports, port) + r.mu.Unlock() + r.portReqs <- portReqInfo{reqType: "cancel-tcpip-forward", port: port} + req.Reply(true, nil) + + default: + req.Reply(false, nil) + } + } +} + +// openForwardedTCPIP opens a forwarded-tcpip channel to the host with V2 extra data. +func (r *mockV2Relay) openForwardedTCPIP(port uint16) (ssh.Channel, error) { + data := &messages.PortRelayConnectRequest{ + Host: "127.0.0.1", + Port: uint32(port), + OriginatorIP: "127.0.0.1", + OriginatorPort: 0, + AccessToken: "", + IsE2EEncryptionRequested: false, + } + extraData, err := data.Marshal() + if err != nil { + return nil, err + } + + ch, reqs, err := r.conn.OpenChannel("forwarded-tcpip", extraData) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(reqs) + return ch, nil +} + +// openDirectTCPIP opens a direct-tcpip channel to the host with V2 extra data. +func (r *mockV2Relay) openDirectTCPIP(port uint16) (ssh.Channel, error) { + data := &messages.PortRelayConnectRequest{ + Host: "127.0.0.1", + Port: uint32(port), + OriginatorIP: "127.0.0.1", + OriginatorPort: 0, + AccessToken: "", + IsE2EEncryptionRequested: false, + } + extraData, err := data.Marshal() + if err != nil { + return nil, err + } + + ch, reqs, err := r.conn.OpenChannel("direct-tcpip", extraData) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(reqs) + return ch, nil +} + +func (r *mockV2Relay) hasPort(port uint16) bool { + r.mu.Lock() + defer r.mu.Unlock() + _, ok := r.ports[port] + return ok +} + +// waitForPortForward waits until a tcpip-forward request for the given port arrives. +func (r *mockV2Relay) waitForPortForward(t *testing.T, port uint16) { + t.Helper() + timeout := time.After(5 * time.Second) + for { + select { + case info := <-r.portReqs: + if info.reqType == "tcpip-forward" && info.port == port { + return + } + case <-timeout: + t.Fatalf("timeout waiting for tcpip-forward for port %d", port) + } + } +} + +func TestHostSessionConnectWithSSHHandshake(t *testing.T) { + _, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Verify the relay connection user is "tunnel". + if relay.conn.User() != "tunnel" { + t.Fatalf("expected user 'tunnel', got %q", relay.conn.User()) + } +} + +func TestHostSessionAcceptsForwardedTcpip(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Register a port so the channel will be accepted. + session.AddPort(8080, "test-token") + relay.waitForPortForward(t, 8080) + + // Open a forwarded-tcpip channel from the relay side. + ch, err := relay.openForwardedTCPIP(8080) + if err != nil { + // The channel should be accepted (port is registered). + // It's OK if the local dial fails — the channel was at least accepted. + return + } + ch.Close() +} + +func TestHostSessionRejectsUnknownChannelType(t *testing.T) { + _, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Try to open an unknown channel type. + _, _, err := relay.conn.OpenChannel("unknown-type", nil) + if err == nil { + t.Fatal("expected error for unknown channel type, got nil") + } + + // Verify it's an OpenChannelError with UnknownChannelType reason. + if openErr, ok := err.(*ssh.OpenChannelError); ok { + if openErr.Reason != ssh.UnknownChannelType { + t.Fatalf("expected UnknownChannelType, got %v", openErr.Reason) + } + } else { + t.Fatalf("expected *ssh.OpenChannelError, got %T: %v", err, err) + } +} + +func TestHostSessionRejectsUnregisteredPort(t *testing.T) { + _, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Try to open a forwarded-tcpip to an unregistered port. + _, err := relay.openForwardedTCPIP(9999) + if err == nil { + t.Fatal("expected error for unregistered port, got nil") + } + + // Verify it's a Prohibited rejection. + if openErr, ok := err.(*ssh.OpenChannelError); ok { + if openErr.Reason != ssh.Prohibited { + t.Fatalf("expected Prohibited, got %v", openErr.Reason) + } + } else { + t.Fatalf("expected *ssh.OpenChannelError, got %T: %v", err, err) + } +} + +func TestConcurrentAddPort(t *testing.T) { + session, _, cleanup := setupHostAndRelay(t) + defer cleanup() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(port uint16) { + defer wg.Done() + session.AddPort(port, "test-token") + }(uint16(9000 + i)) + } + wg.Wait() + + // Verify all 10 ports were added. + session.portsMu.RLock() + count := len(session.ports) + session.portsMu.RUnlock() + + if count != 10 { + t.Fatalf("expected 10 ports, got %d", count) + } +} + +func TestAddPortSendsTcpipForward(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Add ports before and verify relay receives tcpip-forward. + session.AddPort(8080, "test-token") + session.AddPort(3000, "test-token") + + // Wait for both port forward requests. + timeout := time.After(5 * time.Second) + received := 0 + for received < 2 { + select { + case info := <-relay.portReqs: + if info.reqType != "tcpip-forward" { + t.Fatalf("expected tcpip-forward, got %s", info.reqType) + } + received++ + case <-timeout: + t.Fatalf("timeout waiting for tcpip-forward requests, got %d of 2", received) + } + } + + // Verify relay knows about both ports. + if !relay.hasPort(8080) { + t.Fatal("relay should have port 8080") + } + if !relay.hasPort(3000) { + t.Fatal("relay should have port 3000") + } +} + +func TestRemovePortSendsCancelTcpipForward(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Add a port. + session.AddPort(8080, "test-token") + relay.waitForPortForward(t, 8080) + + // Remove the port. + session.RemovePort(8080, "test-token") + + // Wait for cancel-tcpip-forward. + timeout := time.After(5 * time.Second) + for { + select { + case info := <-relay.portReqs: + if info.reqType == "cancel-tcpip-forward" && info.port == 8080 { + return // Success + } + case <-timeout: + t.Fatal("timeout waiting for cancel-tcpip-forward") + } + } +} + +func TestRemovePortNoOpForUnregisteredPort(t *testing.T) { + session, _, cleanup := setupHostAndRelay(t) + defer cleanup() + + // Remove a port that was never added — should not send anything. + session.RemovePort(9999, "test-token") + + // Give time for any spurious request to arrive. + time.Sleep(200 * time.Millisecond) + + // Verify port list is still empty. + session.portsMu.RLock() + count := len(session.ports) + session.portsMu.RUnlock() + if count != 0 { + t.Fatalf("expected 0 ports, got %d", count) + } +} + +func TestCloseWhileChannelsOpen(t *testing.T) { + session, relay, cleanup := setupHostAndRelay(t) + _ = cleanup // We'll close manually. + + // Add a port and open a channel. + session.AddPort(8080, "test-token") + relay.waitForPortForward(t, 8080) + + // The channel open will likely fail because there's no local listener, + // but that's fine — we're testing close safety. + relay.openForwardedTCPIP(8080) + + time.Sleep(200 * time.Millisecond) + + // Close the host session. This should not panic. + session.Close() + relay.conn.Close() + + // Give goroutines time to clean up. + time.Sleep(100 * time.Millisecond) +} + +func TestAddPortDeduplicates(t *testing.T) { + hostKey := newTestHostKey() + logger := newTestLogger() + hostEnd, relayEnd := net.Pipe() + defer hostEnd.Close() + defer relayEnd.Close() + session := NewHostSSHSession(hostEnd, hostKey, logger, "test-token", HostWebSocketSubProtocolV2) + + session.AddPort(8080, "test-token") + session.AddPort(8080, "test-token") // Adding same port again — should be deduplicated. + + // Verify it was added only once. + session.portsMu.RLock() + count := 0 + for _, p := range session.ports { + if p == 8080 { + count++ + } + } + session.portsMu.RUnlock() + + if count != 1 { + t.Fatalf("expected port 8080 to appear once, got %d", count) + } +} + +func TestAddPortWhenNotConnected(t *testing.T) { + hostKey := newTestHostKey() + logger := newTestLogger() + hostEnd, relayEnd := net.Pipe() + defer hostEnd.Close() + defer relayEnd.Close() + session := NewHostSSHSession(hostEnd, hostKey, logger, "test-token", HostWebSocketSubProtocolV2) + + // Session is not connected, but AddPort should still add to the local list. + session.AddPort(8080, "test-token") + + // Verify port was added. + session.portsMu.RLock() + found := false + for _, p := range session.ports { + if p == 8080 { + found = true + } + } + session.portsMu.RUnlock() + + if !found { + t.Fatal("expected port 8080 to be added") + } +} + +// ============================================================ +// V1 tests +// ============================================================ + +// mockV1Relay holds the relay side of an SSH connection for V1 tests. +// In V1, the relay rejects all global requests and opens client-ssh-session-stream +// channels to simulate client connections. +type mockV1Relay struct { + conn ssh.Conn +} + +// setupHostAndRelayV1 creates a HostSSHSession with V1 protocol connected via TCP +// loopback with a V1 mock relay. +func setupHostAndRelayV1(t *testing.T) (*HostSSHSession, *mockV1Relay, func()) { + t.Helper() + + hostEnd, relayEnd := tcpConnPair(t) + hostKey := newTestHostKey() + logger := newTestLogger() + + session := NewHostSSHSession(hostEnd, hostKey, logger, "", HostWebSocketSubProtocol) + + // Relay side: SSH server + relayConfig := &ssh.ServerConfig{NoClientAuth: true} + privateKey, err := ssh.ParsePrivateKey([]byte(testRSAPrivateKey)) + if err != nil { + t.Fatalf("failed to parse private key: %v", err) + } + relayConfig.AddHostKey(privateKey) + + relay := &mockV1Relay{} + + var relayConn ssh.Conn + var relayErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var reqs <-chan *ssh.Request + var chans <-chan ssh.NewChannel + relayConn, chans, reqs, relayErr = ssh.NewServerConn(relayEnd, relayConfig) + if relayErr == nil { + // V1 relay rejects all global requests. + go func() { + for req := range reqs { + req.Reply(false, nil) + } + }() + // Drain incoming channels. + go func() { + for ch := range chans { + ch.Reject(ssh.Prohibited, "not supported") + } + }() + } + }() + + ctx := context.Background() + if err := session.Connect(ctx); err != nil { + t.Fatalf("failed to connect host session: %v", err) + } + + wg.Wait() + if relayErr != nil { + t.Fatalf("relay SSH handshake failed: %v", relayErr) + } + + relay.conn = relayConn + + cleanup := func() { + session.Close() + relayConn.Close() + } + + return session, relay, cleanup +} + +// openClientSession opens a client-ssh-session-stream channel on the relay +// and performs a nested SSH client handshake, returning the *ssh.Client. +func (r *mockV1Relay) openClientSession(t *testing.T) *ssh.Client { + t.Helper() + + channel, reqs, err := r.conn.OpenChannel("client-ssh-session-stream", nil) + if err != nil { + t.Fatalf("failed to open client-ssh-session-stream: %v", err) + } + go ssh.DiscardRequests(reqs) + + // Wrap channel as net.Conn for nested SSH handshake. + conn := &testChannelConn{Channel: channel} + + clientConfig := &ssh.ClientConfig{ + User: "tunnel", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + } + + sshConn, chans, globalReqs, err := ssh.NewClientConn(conn, "", clientConfig) + if err != nil { + t.Fatalf("nested SSH client handshake failed: %v", err) + } + + return ssh.NewClient(sshConn, chans, globalReqs) +} + +// testChannelConn wraps an ssh.Channel as a net.Conn for test use. +type testChannelConn struct { + ssh.Channel +} + +func (c *testChannelConn) LocalAddr() net.Addr { return testDummyAddr{} } +func (c *testChannelConn) RemoteAddr() net.Addr { return testDummyAddr{} } +func (c *testChannelConn) SetDeadline(t time.Time) error { return nil } +func (c *testChannelConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testChannelConn) SetWriteDeadline(t time.Time) error { return nil } + +type testDummyAddr struct{} + +func (testDummyAddr) Network() string { return "test" } +func (testDummyAddr) String() string { return "test" } + +func TestV1HostSessionAcceptsClientStream(t *testing.T) { + _, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // Open a client-ssh-session-stream and verify nested SSH handshake completes. + client := relay.openClientSession(t) + defer client.Close() + + // If we got here, the nested SSH handshake succeeded. +} + +func TestV1HostSessionRejectsUnknownChannel(t *testing.T) { + _, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // Try to open an unknown channel type. + _, _, err := relay.conn.OpenChannel("unknown-type", nil) + if err == nil { + t.Fatal("expected error for unknown channel type, got nil") + } + + if openErr, ok := err.(*ssh.OpenChannelError); ok { + if openErr.Reason != ssh.UnknownChannelType { + t.Fatalf("expected UnknownChannelType, got %v", openErr.Reason) + } + } else { + t.Fatalf("expected *ssh.OpenChannelError, got %T: %v", err, err) + } +} + +func TestV1PortForwardToClient(t *testing.T) { + session, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // First add a port, then connect a client. + session.AddPort(8080, "") + + client := relay.openClientSession(t) + defer client.Close() + + // The host should send tcpip-forward to this client for port 8080. + // Listen for it using client.HandleChannelOpen or check global requests. + // In V1, the host sends tcpip-forward as a global request to the client. + // The ssh.Client handles global requests; we need to check them. + // Wait for the tcpip-forward request to arrive. + time.Sleep(500 * time.Millisecond) + + // Verify the port was registered. + if !session.HasPort(8080) { + t.Fatal("expected port 8080 to be registered") + } +} + +func TestV1AddPortNotifiesExistingClients(t *testing.T) { + session, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // Connect a client first. + client := relay.openClientSession(t) + defer client.Close() + + // Give time for the client to be registered. + time.Sleep(300 * time.Millisecond) + + // Now add a port — should notify the existing client. + session.AddPort(9090, "") + + // Wait for the notification to propagate. + time.Sleep(500 * time.Millisecond) + + // Verify the port is in the session. + if !session.HasPort(9090) { + t.Fatal("expected port 9090 to be registered") + } +} + +func TestV1RemovePortNotifiesClients(t *testing.T) { + session, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // Connect a client. + client := relay.openClientSession(t) + defer client.Close() + + // Give time for registration. + time.Sleep(300 * time.Millisecond) + + // Add then remove a port. + session.AddPort(7070, "") + time.Sleep(200 * time.Millisecond) + + session.RemovePort(7070, "") + time.Sleep(200 * time.Millisecond) + + // Verify the port was removed. + if session.HasPort(7070) { + t.Fatal("expected port 7070 to be removed") + } +} + +func TestV1MultipleClients(t *testing.T) { + session, relay, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + // Add a port first. + session.AddPort(5050, "") + + // Connect multiple clients. + client1 := relay.openClientSession(t) + defer client1.Close() + + client2 := relay.openClientSession(t) + defer client2.Close() + + // Give time for both to register and receive port forwards. + time.Sleep(500 * time.Millisecond) + + // Verify the session has the expected number of clients. + session.clientsMu.RLock() + clientCount := len(session.clients) + session.clientsMu.RUnlock() + + if clientCount != 2 { + t.Fatalf("expected 2 clients, got %d", clientCount) + } +} + +func TestV1ConnectionProtocol(t *testing.T) { + session, _, cleanup := setupHostAndRelayV1(t) + defer cleanup() + + if session.ConnectionProtocol() != HostWebSocketSubProtocol { + t.Fatalf("expected %q, got %q", HostWebSocketSubProtocol, session.ConnectionProtocol()) + } +} + +func TestV2ConnectionProtocol(t *testing.T) { + session, _, cleanup := setupHostAndRelay(t) + defer cleanup() + + if session.ConnectionProtocol() != HostWebSocketSubProtocolV2 { + t.Fatalf("expected %q, got %q", HostWebSocketSubProtocolV2, session.ConnectionProtocol()) + } +} diff --git a/go/tunnels/ssh/messages/port_forward_cancel_request.go b/go/tunnels/ssh/messages/port_forward_cancel_request.go new file mode 100644 index 00000000..bb564f2c --- /dev/null +++ b/go/tunnels/ssh/messages/port_forward_cancel_request.go @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package messages + +import ( + "bytes" + "fmt" + "io" +) + +const ( + PortForwardCancelRequestType = "cancel-tcpip-forward" +) + +type PortForwardCancelRequest struct { + addressToBind string + port uint32 +} + +func NewPortForwardCancelRequest(addressToBind string, port uint32) *PortForwardCancelRequest { + return &PortForwardCancelRequest{ + addressToBind: addressToBind, + port: port, + } +} + +func (pfcr *PortForwardCancelRequest) Port() uint32 { + return pfcr.port +} + +func (pfcr *PortForwardCancelRequest) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := writeString(buf, pfcr.addressToBind); err != nil { + return nil, fmt.Errorf("error writing address to bind: %w", err) + } + if err := writeUint32(buf, pfcr.port); err != nil { + return nil, fmt.Errorf("error writing port: %w", err) + } + return buf.Bytes(), nil +} + +func (pfcr *PortForwardCancelRequest) Unmarshal(buf io.Reader) error { + addressToBind, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading address to bind: %w", err) + } + port, err := readUint32(buf) + if err != nil { + return fmt.Errorf("error reading port: %w", err) + } + pfcr.addressToBind = addressToBind + pfcr.port = port + return nil +} diff --git a/go/tunnels/ssh/messages/port_relay_connect_request.go b/go/tunnels/ssh/messages/port_relay_connect_request.go new file mode 100644 index 00000000..5e0fad59 --- /dev/null +++ b/go/tunnels/ssh/messages/port_relay_connect_request.go @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package messages + +import ( + "bytes" + "fmt" + "io" +) + +// PortRelayConnectRequest is the V2 forwarded-tcpip channel extra data sent +// by the relay when opening a channel to the host. It extends the standard +// forwarded-tcpip fields with AccessToken and IsE2EEncryptionRequested. +// +// Wire format: +// | host (string) | port (uint32) | originatorIP (string) | originatorPort (uint32) | +// | accessToken (string) | isE2EEncryptionRequested (bool) | +type PortRelayConnectRequest struct { + Host string + Port uint32 + OriginatorIP string + OriginatorPort uint32 + AccessToken string + IsE2EEncryptionRequested bool +} + +// Marshal serializes the request to wire format. +func (r *PortRelayConnectRequest) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := writeString(buf, r.Host); err != nil { + return nil, fmt.Errorf("error writing host: %w", err) + } + if err := writeUint32(buf, r.Port); err != nil { + return nil, fmt.Errorf("error writing port: %w", err) + } + if err := writeString(buf, r.OriginatorIP); err != nil { + return nil, fmt.Errorf("error writing originator IP: %w", err) + } + if err := writeUint32(buf, r.OriginatorPort); err != nil { + return nil, fmt.Errorf("error writing originator port: %w", err) + } + if err := writeString(buf, r.AccessToken); err != nil { + return nil, fmt.Errorf("error writing access token: %w", err) + } + if err := writeBool(buf, r.IsE2EEncryptionRequested); err != nil { + return nil, fmt.Errorf("error writing isE2EEncryptionRequested: %w", err) + } + return buf.Bytes(), nil +} + +// Unmarshal deserializes the request from wire format. +func (r *PortRelayConnectRequest) Unmarshal(buf io.Reader) error { + host, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading host: %w", err) + } + port, err := readUint32(buf) + if err != nil { + return fmt.Errorf("error reading port: %w", err) + } + originatorIP, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading originator IP: %w", err) + } + originatorPort, err := readUint32(buf) + if err != nil { + return fmt.Errorf("error reading originator port: %w", err) + } + accessToken, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading access token: %w", err) + } + isE2EEncryptionRequested, err := readBool(buf) + if err != nil { + return fmt.Errorf("error reading isE2EEncryptionRequested: %w", err) + } + r.Host = host + r.Port = port + r.OriginatorIP = originatorIP + r.OriginatorPort = originatorPort + r.AccessToken = accessToken + r.IsE2EEncryptionRequested = isE2EEncryptionRequested + return nil +} diff --git a/go/tunnels/ssh/messages/port_relay_request.go b/go/tunnels/ssh/messages/port_relay_request.go new file mode 100644 index 00000000..a4685ffb --- /dev/null +++ b/go/tunnels/ssh/messages/port_relay_request.go @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package messages + +import ( + "bytes" + "fmt" + "io" +) + +// PortRelayRequest is the V2 tcpip-forward request payload sent by the host +// to the relay. It extends the standard tcpip-forward with an AccessToken field. +// +// Wire format: | addressToBind (string) | port (uint32) | accessToken (string) | +type PortRelayRequest struct { + addressToBind string + port uint32 + accessToken string +} + +// NewPortRelayRequest creates a new PortRelayRequest. +func NewPortRelayRequest(addressToBind string, port uint32, accessToken string) *PortRelayRequest { + return &PortRelayRequest{ + addressToBind: addressToBind, + port: port, + accessToken: accessToken, + } +} + +// Port returns the port number. +func (r *PortRelayRequest) Port() uint32 { + return r.port +} + +// AccessToken returns the access token. +func (r *PortRelayRequest) AccessToken() string { + return r.accessToken +} + +// Marshal serializes the request to wire format. +func (r *PortRelayRequest) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := writeString(buf, r.addressToBind); err != nil { + return nil, fmt.Errorf("error writing address to bind: %w", err) + } + if err := writeUint32(buf, r.port); err != nil { + return nil, fmt.Errorf("error writing port: %w", err) + } + if err := writeString(buf, r.accessToken); err != nil { + return nil, fmt.Errorf("error writing access token: %w", err) + } + return buf.Bytes(), nil +} + +// Unmarshal deserializes the request from wire format. +func (r *PortRelayRequest) Unmarshal(buf io.Reader) error { + addressToBind, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading address to bind: %w", err) + } + port, err := readUint32(buf) + if err != nil { + return fmt.Errorf("error reading port: %w", err) + } + accessToken, err := readString(buf) + if err != nil { + return fmt.Errorf("error reading access token: %w", err) + } + r.addressToBind = addressToBind + r.port = port + r.accessToken = accessToken + return nil +} diff --git a/go/tunnels/ssh/messages/readers.go b/go/tunnels/ssh/messages/readers.go index b36281c7..92220b17 100644 --- a/go/tunnels/ssh/messages/readers.go +++ b/go/tunnels/ssh/messages/readers.go @@ -15,6 +15,14 @@ func readUint32(buf io.Reader) (i uint32, err error) { return i, nil } +func readBool(buf io.Reader) (bool, error) { + var b [1]byte + if _, err := io.ReadFull(buf, b[:]); err != nil { + return false, err + } + return b[0] != 0, nil +} + func readString(buf io.Reader) (s string, err error) { var l uint32 if l, err = readUint32(buf); err != nil { diff --git a/go/tunnels/test/mock_relay_host.go b/go/tunnels/test/mock_relay_host.go new file mode 100644 index 00000000..27e4f1be --- /dev/null +++ b/go/tunnels/test/mock_relay_host.go @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelstest + +import ( + "bytes" + "fmt" + "net" + "sync" + "time" + + "github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages" + "golang.org/x/crypto/ssh" +) + +// MockRelayForHost simulates a V2 relay server using net.Pipe() for host unit tests. +// It connects as an SSH server to the host's SSH client on the relay end of the pipe. +// In V2, the relay handles tcpip-forward/cancel-tcpip-forward global requests and +// opens forwarded-tcpip channels directly to the host (no nested SSH). +type MockRelayForHost struct { + sshConn ssh.Conn + + mu sync.Mutex + ports map[uint16]struct{} +} + +// NewMockRelayForHost creates a new MockRelayForHost that connects as an SSH server +// to the host on the given net.Conn (typically the relay end of a net.Pipe()). +func NewMockRelayForHost(relayEnd net.Conn) (*MockRelayForHost, error) { + sshConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) + if err != nil { + return nil, fmt.Errorf("error parsing private key: %w", err) + } + sshConfig.AddHostKey(privateKey) + + serverConn, chans, reqs, err := ssh.NewServerConn(relayEnd, sshConfig) + if err != nil { + return nil, fmt.Errorf("error creating SSH server connection: %w", err) + } + + m := &MockRelayForHost{ + sshConn: serverConn, + ports: make(map[uint16]struct{}), + } + + go m.handleGlobalRequests(reqs) + + // Drain incoming channels (reject all). + go func() { + for ch := range chans { + ch.Reject(ssh.Prohibited, "not supported") + } + }() + + return m, nil +} + +// handleGlobalRequests handles tcpip-forward and cancel-tcpip-forward from the host. +func (m *MockRelayForHost) handleGlobalRequests(reqs <-chan *ssh.Request) { + for req := range reqs { + switch req.Type { + case "tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + m.mu.Lock() + m.ports[uint16(prr.Port())] = struct{}{} + m.mu.Unlock() + req.Reply(true, nil) + + case "cancel-tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + m.mu.Lock() + delete(m.ports, uint16(prr.Port())) + m.mu.Unlock() + req.Reply(true, nil) + + default: + req.Reply(false, nil) + } + } +} + +// SimulateClientConnection opens a forwarded-tcpip channel to the host with +// V2 extra data and returns the channel as a net.Conn for test client use. +func (m *MockRelayForHost) SimulateClientConnection(port uint16) (net.Conn, error) { + data := &messages.PortRelayConnectRequest{ + Host: "127.0.0.1", + Port: uint32(port), + OriginatorIP: "127.0.0.1", + OriginatorPort: 0, + AccessToken: "", + IsE2EEncryptionRequested: false, + } + extraData, err := data.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshaling V2 channel data: %w", err) + } + + channel, reqs, err := m.sshConn.OpenChannel("forwarded-tcpip", extraData) + if err != nil { + return nil, fmt.Errorf("error opening forwarded-tcpip channel: %w", err) + } + go ssh.DiscardRequests(reqs) + + return &channelNetConn{Channel: channel}, nil +} + +// HasPort reports whether the given port has been registered via tcpip-forward. +func (m *MockRelayForHost) HasPort(port uint16) bool { + m.mu.Lock() + defer m.mu.Unlock() + _, ok := m.ports[port] + return ok +} + +// Close cleanly shuts down the mock relay. +func (m *MockRelayForHost) Close() error { + if m.sshConn != nil { + m.sshConn.Close() + } + return nil +} + +// channelNetConn wraps an ssh.Channel as a net.Conn for test use. +type channelNetConn struct { + ssh.Channel +} + +func (c *channelNetConn) LocalAddr() net.Addr { return dummyTestAddr{} } +func (c *channelNetConn) RemoteAddr() net.Addr { return dummyTestAddr{} } +func (c *channelNetConn) SetDeadline(t time.Time) error { return nil } +func (c *channelNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c *channelNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type dummyTestAddr struct{} + +func (dummyTestAddr) Network() string { return "tunnel-test" } +func (dummyTestAddr) String() string { return "tunnel-test" } + +// MockRelayForHostV1 simulates a V1 relay server using net.Pipe() for host unit tests. +// In V1, the relay rejects all global requests (no tcpip-forward handling) and +// simulates client connections by opening client-ssh-session-stream channels +// with a nested SSH client handshake inside. +type MockRelayForHostV1 struct { + sshConn ssh.Conn +} + +// NewMockRelayForHostV1 creates a new V1 mock relay that connects as an SSH server +// to the host on the given net.Conn. +func NewMockRelayForHostV1(relayEnd net.Conn) (*MockRelayForHostV1, error) { + sshConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) + if err != nil { + return nil, fmt.Errorf("error parsing private key: %w", err) + } + sshConfig.AddHostKey(privateKey) + + serverConn, chans, reqs, err := ssh.NewServerConn(relayEnd, sshConfig) + if err != nil { + return nil, fmt.Errorf("error creating SSH server connection: %w", err) + } + + m := &MockRelayForHostV1{ + sshConn: serverConn, + } + + // V1 relay rejects all global requests. + go func() { + for req := range reqs { + req.Reply(false, nil) + } + }() + + // Drain incoming channels (reject all). + go func() { + for ch := range chans { + ch.Reject(ssh.Prohibited, "not supported") + } + }() + + return m, nil +} + +// SimulateClientConnection opens a client-ssh-session-stream channel to the host, +// performs a nested SSH client handshake inside it, and returns the *ssh.Client +// so tests can receive tcpip-forward requests and open channels on the nested SSH. +func (m *MockRelayForHostV1) SimulateClientConnection() (*ssh.Client, error) { + channel, reqs, err := m.sshConn.OpenChannel("client-ssh-session-stream", nil) + if err != nil { + return nil, fmt.Errorf("error opening client-ssh-session-stream channel: %w", err) + } + go ssh.DiscardRequests(reqs) + + // Wrap channel as net.Conn for the nested SSH client handshake. + conn := &channelNetConn{Channel: channel} + + // Perform nested SSH client handshake inside the channel. + clientConfig := &ssh.ClientConfig{ + User: "tunnel", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + } + + sshConn, chans, globalReqs, err := ssh.NewClientConn(conn, "", clientConfig) + if err != nil { + channel.Close() + return nil, fmt.Errorf("nested SSH client handshake failed: %w", err) + } + + client := ssh.NewClient(sshConn, chans, globalReqs) + return client, nil +} + +// Close cleanly shuts down the V1 mock relay. +func (m *MockRelayForHostV1) Close() error { + if m.sshConn != nil { + m.sshConn.Close() + } + return nil +} diff --git a/go/tunnels/test/relay_host_server.go b/go/tunnels/test/relay_host_server.go new file mode 100644 index 00000000..0eba8810 --- /dev/null +++ b/go/tunnels/test/relay_host_server.go @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnelstest + +import ( + "bytes" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "time" + + "github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages" + "github.com/gorilla/websocket" + "golang.org/x/crypto/ssh" +) + +// RelayHostServer is a WebSocket-level mock relay server for host tests. +// It validates the full WebSocket upgrade + SSH handshake, handles +// tcpip-forward/cancel-tcpip-forward global requests (V2), and can simulate +// client connections by opening forwarded-tcpip channels (V2) or +// client-ssh-session-stream channels (V1). +type RelayHostServer struct { + httpServer *httptest.Server + accessToken string + forceV1 bool + errc chan error + + mu sync.Mutex + sshConn ssh.Conn + ports map[uint16]struct{} + connected chan struct{} + negotiatedProtocol string +} + +// RelayHostServerOption is a functional option for configuring RelayHostServer. +type RelayHostServerOption func(*RelayHostServer) + +// WithHostAccessToken configures the expected access token for the relay. +func WithHostAccessToken(token string) RelayHostServerOption { + return func(s *RelayHostServer) { + s.accessToken = token + } +} + +// WithProtocolV1Only forces the relay to negotiate V1 even if V2 is offered. +func WithProtocolV1Only() RelayHostServerOption { + return func(s *RelayHostServer) { + s.forceV1 = true + } +} + +// NewRelayHostServer creates a new WebSocket-level mock relay server for host tests. +func NewRelayHostServer(opts ...RelayHostServerOption) (*RelayHostServer, error) { + server := &RelayHostServer{ + errc: make(chan error, 1), + connected: make(chan struct{}), + ports: make(map[uint16]struct{}), + } + + for _, opt := range opts { + opt(server) + } + + server.httpServer = httptest.NewServer(http.HandlerFunc(server.handleConnection)) + + return server, nil +} + +// URL returns the WebSocket URL for the host to connect to. +func (s *RelayHostServer) URL() string { + return "ws" + s.httpServer.URL[4:] // convert http:// to ws:// +} + +// NegotiatedProtocol returns the subprotocol selected during the WebSocket handshake. +func (s *RelayHostServer) NegotiatedProtocol() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.negotiatedProtocol +} + +// SimulateClientConnection opens a forwarded-tcpip channel to the host with +// V2 extra data and returns the channel as a net.Conn. +func (s *RelayHostServer) SimulateClientConnection(port uint16) (net.Conn, error) { + s.mu.Lock() + conn := s.sshConn + s.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("relay not connected") + } + + data := &messages.PortRelayConnectRequest{ + Host: "127.0.0.1", + Port: uint32(port), + OriginatorIP: "127.0.0.1", + OriginatorPort: 0, + AccessToken: "", + IsE2EEncryptionRequested: false, + } + extraData, err := data.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshaling V2 channel data: %w", err) + } + + channel, reqs, err := conn.OpenChannel("forwarded-tcpip", extraData) + if err != nil { + return nil, fmt.Errorf("error opening forwarded-tcpip channel: %w", err) + } + go ssh.DiscardRequests(reqs) + + return &channelNetConn{Channel: channel}, nil +} + +// SimulateClientConnectionV1 opens a client-ssh-session-stream channel (V1) +// and performs a nested SSH client handshake. Returns the *ssh.Client. +func (s *RelayHostServer) SimulateClientConnectionV1() (*ssh.Client, error) { + s.mu.Lock() + conn := s.sshConn + s.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("relay not connected") + } + + channel, reqs, err := conn.OpenChannel("client-ssh-session-stream", nil) + if err != nil { + return nil, fmt.Errorf("error opening client-ssh-session-stream channel: %w", err) + } + go ssh.DiscardRequests(reqs) + + // Wrap as net.Conn for nested SSH handshake. + chanConn := &channelNetConn{Channel: channel} + + clientConfig := &ssh.ClientConfig{ + User: "tunnel", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + } + + sshConn, chans, globalReqs, err := ssh.NewClientConn(chanConn, "", clientConfig) + if err != nil { + channel.Close() + return nil, fmt.Errorf("nested SSH client handshake failed: %w", err) + } + + return ssh.NewClient(sshConn, chans, globalReqs), nil +} + +// HasPort reports whether the given port has been registered via tcpip-forward. +func (s *RelayHostServer) HasPort(port uint16) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.ports[port] + return ok +} + +// Err returns the error channel for the relay server. +func (s *RelayHostServer) Err() <-chan error { + return s.errc +} + +// Close shuts down the server. +func (s *RelayHostServer) Close() error { + s.mu.Lock() + conn := s.sshConn + s.mu.Unlock() + + if conn != nil { + conn.Close() + } + s.httpServer.Close() + return nil +} + +func (s *RelayHostServer) sendError(err error) { + select { + case s.errc <- err: + default: + } +} + +var hostUpgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, +} + +func (s *RelayHostServer) handleConnection(w http.ResponseWriter, r *http.Request) { + // Validate access token if configured. + if s.accessToken != "" { + if r.Header.Get("Authorization") != s.accessToken { + s.sendError(fmt.Errorf("invalid access token")) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + + // Select subprotocol from offered list. + protocols := websocket.Subprotocols(r) + selectedProtocol := "" + if s.forceV1 { + // Force V1: only accept tunnel-relay-host. + for _, p := range protocols { + if p == "tunnel-relay-host" { + selectedProtocol = p + break + } + } + } else { + // Prefer V2, fall back to V1. + for _, p := range protocols { + if p == "tunnel-relay-host-v2-dev" { + selectedProtocol = p + break + } + } + if selectedProtocol == "" { + for _, p := range protocols { + if p == "tunnel-relay-host" { + selectedProtocol = p + break + } + } + } + } + + if selectedProtocol == "" { + s.sendError(fmt.Errorf("no supported subprotocol offered: %v", protocols)) + http.Error(w, "bad subprotocol", http.StatusBadRequest) + return + } + + // Upgrade to WebSocket. + respHeader := http.Header{} + respHeader.Set("Sec-WebSocket-Protocol", selectedProtocol) + c, err := hostUpgrader.Upgrade(w, r, respHeader) + if err != nil { + s.sendError(fmt.Errorf("error upgrading to websocket: %w", err)) + return + } + + socketConn := newSocketConn(c) + + // Connect as SSH server to the host's SSH client. + sshConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) + if err != nil { + s.sendError(fmt.Errorf("error parsing private key: %w", err)) + return + } + sshConfig.AddHostKey(privateKey) + + serverConn, _, reqs, err := ssh.NewServerConn(socketConn, sshConfig) + if err != nil { + s.sendError(fmt.Errorf("error creating SSH server conn: %w", err)) + return + } + + // Handle global requests based on protocol. + if selectedProtocol == "tunnel-relay-host" { + // V1: reject all global requests (relay doesn't handle tcpip-forward). + go func() { + for req := range reqs { + req.Reply(false, nil) + } + }() + } else { + // V2: handle tcpip-forward/cancel-tcpip-forward. + go s.handleGlobalRequests(reqs) + } + + s.mu.Lock() + s.sshConn = serverConn + s.negotiatedProtocol = selectedProtocol + // Signal that the connection is established. + select { + case <-s.connected: + // Already closed (reconnect scenario) — make a new channel. + s.connected = make(chan struct{}) + default: + } + close(s.connected) + s.mu.Unlock() + + // Block until connection closes. + serverConn.Wait() +} + +// handleGlobalRequests handles tcpip-forward and cancel-tcpip-forward from the host. +func (s *RelayHostServer) handleGlobalRequests(reqs <-chan *ssh.Request) { + for req := range reqs { + switch req.Type { + case "tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + s.mu.Lock() + s.ports[uint16(prr.Port())] = struct{}{} + s.mu.Unlock() + req.Reply(true, nil) + + case "cancel-tcpip-forward": + var prr messages.PortRelayRequest + if err := prr.Unmarshal(bytes.NewReader(req.Payload)); err != nil { + req.Reply(false, nil) + continue + } + s.mu.Lock() + delete(s.ports, uint16(prr.Port())) + s.mu.Unlock() + req.Reply(true, nil) + + default: + req.Reply(false, nil) + } + } +} + +// WaitForConnection waits for the host to connect to the relay server. +func (s *RelayHostServer) WaitForConnection(timeout time.Duration) error { + s.mu.Lock() + ch := s.connected + s.mu.Unlock() + + select { + case <-ch: + return nil + case err := <-s.errc: + return err + case <-time.After(timeout): + return fmt.Errorf("timeout waiting for host connection") + } +} diff --git a/go/tunnels/tunnel_error.go b/go/tunnels/tunnel_error.go new file mode 100644 index 00000000..7f505e88 --- /dev/null +++ b/go/tunnels/tunnel_error.go @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tunnels + +import "fmt" + +// TunnelError represents an error response from the tunnel service +// with an HTTP status code. +type TunnelError struct { + StatusCode int + Message string +} + +func (e *TunnelError) Error() string { + return fmt.Sprintf("tunnel service error (status %d): %s", e.StatusCode, e.Message) +} diff --git a/prd.json b/prd.json new file mode 100644 index 00000000..e29b8bbd --- /dev/null +++ b/prd.json @@ -0,0 +1,436 @@ +{ + "project": "dev-tunnels", + "branchName": "ralph/go-host-e2e-tests", + "description": "Go SDK Host E2E Test Suite - Comprehensive end-to-end validation of the Go host implementation covering all public API methods, error paths, data flows, concurrency, reconnection, and edge cases using in-process mock infrastructure", + "userStories": [ + { + "id": "US-000", + "title": "Verify Go development environment", + "description": "As a developer, I need a working Go development environment so that I can build and run the E2E test suite.", + "acceptanceCriteria": [ + "Go is installed and available on PATH (expected at /opt/homebrew/bin/go on macOS ARM64, or install via `brew install go` if missing)", + "Go version is 1.17 or later (project go.mod specifies go 1.17; current environment has Go 1.26.0)", + "Module dependencies are downloaded: run `go mod download` from the project root", + "Required dependencies verified present: github.com/gorilla/websocket v1.4.2, golang.org/x/crypto v0.23.0", + "`go vet ./go/tunnels/...` passes (validates existing codebase compiles and is vet-clean)", + "`go test -short -v ./go/tunnels/ssh/...` passes (validates existing SSH tests work — confirms test toolchain is functional)", + "The -race flag works: `go test -race -short -run TestChannelConn ./go/tunnels/ssh/` passes", + "No additional system dependencies required (no C compiler, no CGO, no external services)" + ], + "priority": 1, + "passes": true, + "notes": "" + }, + { + "id": "US-001", + "title": "Create shared E2E test infrastructure", + "description": "As a developer, I need reusable test helpers so that all 20 E2E scenarios share common setup patterns and avoid code duplication.", + "acceptanceCriteria": [ + "`e2eMockAPI` struct defined in `go/tunnels/host_e2e_test.go` with fields: server *httptest.Server, manager *Manager, deleteEndpointCalls int32 (atomic), createPortCalls int32 (atomic), remotePorts atomic.Value (stores []TunnelPort), relayURI atomic.Value (stores string), unauthorizedOnce int32 (atomic flag for 401 simulation)", + "`e2eMockAPI` HTTP handler routes: PUT .../endpoints/... returns TunnelEndpoint with HostRelayURI from relayURI, DELETE .../endpoints/... increments deleteEndpointCalls, GET .../tunnels/... returns Tunnel with remotePorts, PUT .../ports/... increments createPortCalls, DELETE .../ports/... returns 200, GET .../tunnels/.../ports/... returns remotePorts", + "`connectClientViaRelay(t, relay)` helper: calls relay.SimulateClientArrival(), performs nested SSH handshake (ssh.NewClientConn with InsecureIgnoreHostKey), drains incoming channels in background, returns ssh.Conn + <-chan *ssh.Request", + "`connectClientWithKeyCapture(t, relay)` helper: same as connectClientViaRelay but uses custom HostKeyCallback to capture the received host public key, returns ssh.Conn + <-chan *ssh.Request + ssh.PublicKey", + "`startEchoServerE2E(t)` helper: listens on tcp4 127.0.0.1:0, accept loop with io.Copy(conn, conn), t.Cleanup closes listener, returns net.Listener + uint16 port", + "`openForwardedChannel(t, sshConn, port)` helper: marshals forwardedTCPIPData{Host: \"127.0.0.1\", Port: uint32(port)}, opens \"forwarded-tcpip\" channel, discards channel requests in background, returns ssh.Channel", + "`drainAndReplyRequests(globalReqs, collected)` helper: background goroutine that reads from globalReqs, sends to collected channel (buffered), replies true nil, runs until globalReqs is closed", + "All helpers defined in go/tunnels/host_e2e_test.go", + "`go vet` and `go build` pass" + ], + "priority": 2, + "passes": true, + "notes": "CRITICAL: Never use net.Pipe() for SSH connections — use RelayHostServer (WebSocket) or tcpConnPair() (TCP loopback). net.Pipe() is unbuffered and deadlocks SSH handshakes. Also: AddPort/RemovePort send SSH global requests with wantReply=true. Without drainAndReplyRequests consuming and replying to these requests, the calls will deadlock." + }, + { + "id": "US-002", + "title": "Test host lifecycle", + "description": "As a developer, I want to verify the complete host lifecycle — construction, connection, status, graceful close, endpoint cleanup, and idempotent close.", + "acceptanceCriteria": [ + "`TestE2E_HostLifecycle` test function in host_e2e_test.go", + "Creates RelayHostServer with WithHostAccessToken(\"Tunnel test-token\")", + "Creates e2eMockAPI pointing to relay server URL", + "NewHost(logger, mgr) succeeds", + "Asserts host.ConnectionStatus() == ConnectionStatusNone", + "Asserts host.HostPublicKeyBase64() is non-empty and base64-decodable", + "host.Connect(ctx, tunnel) succeeds", + "relay.WaitForConnection(5s) succeeds", + "Asserts host.ConnectionStatus() == ConnectionStatusConnected", + "host.Close() succeeds", + "Asserts host.ConnectionStatus() == ConnectionStatusDisconnected", + "Asserts api.deleteEndpointCalls == 1", + "Second host.Close() returns nil (idempotent)", + "api.deleteEndpointCalls still 1", + "Test passes with `go test -v -run TestE2E_HostLifecycle ./go/tunnels/`" + ], + "priority": 3, + "passes": true, + "notes": "" + }, + { + "id": "US-003", + "title": "Test error handling", + "description": "As a developer, I want to verify all error sentinels are returned under the correct conditions via table-driven subtests.", + "acceptanceCriteria": [ + "`TestE2E_ErrorHandling` test function with table-driven subtests", + "Subtest ErrNoManager: NewHost(nil, nil) returns ErrNoManager", + "Subtest ErrAlreadyConnected: connected host, second Connect returns ErrAlreadyConnected", + "Subtest ErrNotConnected_Close: unconnected host, Close() returns ErrNotConnected", + "Subtest ErrNotConnected_Wait: unconnected host, Wait() returns ErrNotConnected", + "Subtest ErrNotConnected_AddPort: unconnected host, AddPort returns ErrNotConnected", + "Subtest ErrNotConnected_RemovePort: unconnected host, RemovePort returns ErrNotConnected", + "Subtest ErrNotConnected_RefreshPorts: unconnected host, RefreshPorts returns ErrNotConnected", + "Subtest ErrPortAlreadyAdded: connected host, port already added, second AddPort returns ErrPortAlreadyAdded", + "Subtest ErrTooManyConnections: host with disconnectReason = 11, Connect returns ErrTooManyConnections", + "Subtest ErrNoHostRelayURI: mock API returns empty HostRelayURI, Connect returns ErrNoHostRelayURI", + "Subtest 409Conflict_Tolerated: mock API returns 409 on port creation, AddPort returns nil", + "All subtests use errors.Is() for sentinel checks", + "Test passes with `go test -v -run TestE2E_ErrorHandling ./go/tunnels/`" + ], + "priority": 4, + "passes": true, + "notes": "" + }, + { + "id": "US-004", + "title": "Test connection status callbacks", + "description": "As a developer, I want to verify the ConnectionStatusChanged callback fires for every status transition in the correct order.", + "acceptanceCriteria": [ + "`TestE2E_ConnectionStatusCallbacks` test function", + "Registers callback that appends ConnectionStatus values to mutex-guarded slice", + "Asserts initial host.ConnectionStatus() == ConnectionStatusNone", + "host.Connect(ctx, tunnel) triggers Connecting, Connected", + "host.Close() triggers Disconnected", + "Asserts transitions == [Connecting, Connected, Disconnected]", + "Exactly 3 transitions (no duplicates, no extras)", + "Test passes with `go test -v -run TestE2E_ConnectionStatusCallbacks ./go/tunnels/`" + ], + "priority": 5, + "passes": true, + "notes": "" + }, + { + "id": "US-005", + "title": "Test port forwarding data flow", + "description": "As a developer, I want to verify the complete data path through the tunnel stack — the most critical E2E scenario.", + "acceptanceCriteria": [ + "`TestE2E_PortForwardingDataFlow` test function", + "Starts TCP echo server via startEchoServerE2E(t)", + "Creates relay, mock API, Host, connects", + "host.AddPort(ctx, &TunnelPort{PortNumber: echoPort}) succeeds", + "connectClientViaRelay(t, relay) returns sshConn, globalReqs", + "Reads tcpip-forward from globalReqs, replies true", + "openForwardedChannel(t, sshConn, echoPort) returns channel", + "Writes \"hello e2e tunnel\" to channel", + "io.ReadFull reads echo response", + "Asserts sent bytes == received bytes", + "Test passes with `go test -v -run TestE2E_PortForwardingDataFlow ./go/tunnels/`" + ], + "priority": 6, + "passes": true, + "notes": "" + }, + { + "id": "US-006", + "title": "Test direct-tcpip and forwarded-tcpip channel types", + "description": "As a developer, I want to verify both SSH channel types work for port forwarding and unregistered ports are rejected.", + "acceptanceCriteria": [ + "`TestE2E_DirectTcpipAndForwardedTcpip` test function", + "Starts echo server, creates Host, connects, adds echo port", + "Connects client, starts drainAndReplyRequests", + "Opens forwarded-tcpip channel, sends \"forwarded-test\", verifies echo", + "Opens direct-tcpip channel to same port, sends \"direct-test\", verifies echo", + "Opens direct-tcpip channel to unregistered port (echoPort + 1000)", + "Asserts channel open fails with *ssh.OpenChannelError, reason ssh.Prohibited", + "Test passes with `go test -v -run TestE2E_DirectTcpipAndForwardedTcpip ./go/tunnels/`" + ], + "priority": 7, + "passes": true, + "notes": "Must use drainAndReplyRequests to avoid wantReply=true deadlock when AddPort sends tcpip-forward to the connected client." + }, + { + "id": "US-007", + "title": "Test multiple ports", + "description": "As a developer, I want to verify multiple ports can be registered and data flows correctly through each independently.", + "acceptanceCriteria": [ + "`TestE2E_MultiplePorts` test function", + "Starts 3 echo servers on 3 different ports", + "Adds all 3 ports via host.AddPort", + "Connects client, receives and replies to 3 tcpip-forward requests", + "For each port: opens forwarded-tcpip channel, sends unique message (port-A, port-B, port-C), verifies echo", + "No cross-contamination between ports", + "Test passes with `go test -v -run TestE2E_MultiplePorts ./go/tunnels/`" + ], + "priority": 8, + "passes": true, + "notes": "" + }, + { + "id": "US-008", + "title": "Test dynamic port management", + "description": "As a developer, I want to verify adding and removing ports while a client is already connected, with real-time SSH notifications.", + "acceptanceCriteria": [ + "`TestE2E_DynamicPortManagement` test function", + "Starts echo server A, creates Host, connects, adds port A", + "Connects client, starts drainAndReplyRequests background goroutine", + "Waits for tcpip-forward for port A in collected channel", + "Verifies data flows through port A", + "Starts echo server B, host.AddPort for port B", + "Waits for tcpip-forward for port B in collected channel", + "Verifies data flows through port B", + "host.RemovePort(ctx, echoPortA)", + "Waits for cancel-tcpip-forward for port A in collected channel", + "Attempts to open forwarded-tcpip to port A — expects rejection with ssh.Prohibited", + "Test passes with `go test -v -run TestE2E_DynamicPortManagement ./go/tunnels/`" + ], + "priority": 9, + "passes": true, + "notes": "CRITICAL: Must use drainAndReplyRequests to avoid wantReply=true deadlock. AddPort/RemovePort send SSH global requests that block until the client replies." + }, + { + "id": "US-009", + "title": "Test port duplicate handling", + "description": "As a developer, I want to verify adding a duplicate port returns the correct error and the port appears exactly once.", + "acceptanceCriteria": [ + "`TestE2E_PortDuplicateHandling` test function", + "Creates Host, connects", + "host.AddPort(ctx, &TunnelPort{PortNumber: 8080}) succeeds", + "host.AddPort(ctx, &TunnelPort{PortNumber: 8080}) returns ErrPortAlreadyAdded", + "Accesses SSH session, verifies ssh.Ports() contains 8080 exactly once", + "Verifies api.createPortCalls == 1", + "Test passes with `go test -v -run TestE2E_PortDuplicateHandling ./go/tunnels/`" + ], + "priority": 10, + "passes": true, + "notes": "" + }, + { + "id": "US-010", + "title": "Test RefreshPorts", + "description": "As a developer, I want to verify RefreshPorts synchronizes local ports with the management service.", + "acceptanceCriteria": [ + "`TestE2E_RefreshPorts` test function", + "Creates Host, connects", + "Accesses SSH session via h.mu.Lock(); ssh := h.ssh; h.mu.Unlock()", + "Manually adds port 5000 to SSH session: ssh.AddPort(5000)", + "Configures mock API: api.remotePorts.Store([]TunnelPort{{PortNumber: 3000}, {PortNumber: 4000}})", + "Calls host.RefreshPorts(ctx)", + "Asserts ssh.HasPort(3000) == true (added from service)", + "Asserts ssh.HasPort(4000) == true (added from service)", + "Asserts ssh.HasPort(5000) == false (removed — not on service)", + "Calls RefreshPorts again — idempotent (no changes)", + "Test passes with `go test -v -run TestE2E_RefreshPorts ./go/tunnels/`" + ], + "priority": 11, + "passes": true, + "notes": "This test accesses internal h.mu and h.ssh fields. This is acceptable because the test is in the same package." + }, + { + "id": "US-011", + "title": "Test large data transfer with SHA256 integrity", + "description": "As a developer, I want to verify 1MB+ data passes through the tunnel with SHA256 integrity verification.", + "acceptanceCriteria": [ + "`TestE2E_LargeDataTransfer` test function", + "Starts echo server, creates Host, connects, adds port", + "Connects client, waits for tcpip-forward, opens forwarded-tcpip channel", + "Generates 1MB payload: payload[i] = byte(i % 256)", + "Computes expectedHash = sha256.Sum256(payload)", + "Goroutine: writes payload to channel, then ch.CloseWrite()", + "Main: io.Copy(received, ch) reads all echoed data", + "Computes actualHash = sha256.Sum256(received.Bytes())", + "Asserts expectedHash == actualHash", + "Asserts received.Len() == 1048576", + "Test passes with `go test -v -run TestE2E_LargeDataTransfer ./go/tunnels/`" + ], + "priority": 12, + "passes": true, + "notes": "Depends on the bidirectional copy fix in handleForwardedTCPIP (waits for both copy directions to drain with CloseWrite() propagation)." + }, + { + "id": "US-012", + "title": "Test bidirectional streaming", + "description": "As a developer, I want to verify sequential request-response pattern works reliably through the tunnel.", + "acceptanceCriteria": [ + "`TestE2E_BidirectionalStreaming` test function", + "Starts echo server, creates Host, connects, adds port", + "Connects client, opens forwarded-tcpip channel", + "Sends 10 messages of increasing size (100, 200, 300, ... 1000 bytes)", + "For each message: generates deterministic content bytes.Repeat([]byte{byte(i)}, size), writes, reads via io.ReadFull, asserts match", + "All 10 messages echo correctly, no ordering issues", + "Test passes with `go test -v -run TestE2E_BidirectionalStreaming ./go/tunnels/`" + ], + "priority": 13, + "passes": true, + "notes": "" + }, + { + "id": "US-013", + "title": "Test multiple concurrent clients", + "description": "As a developer, I want to verify multiple clients connect simultaneously and each independently forwards data.", + "acceptanceCriteria": [ + "`TestE2E_MultipleConcurrentClients` test function", + "Starts echo server, creates Host, connects, adds echo port", + "Spawns 5 goroutines, each: connects via connectClientViaRelay, starts drainAndReplyRequests, waits for tcpip-forward, opens forwarded-tcpip, writes fmt.Sprintf(\"client-%d\", i), verifies echo", + "Uses sync.WaitGroup to wait for all goroutines", + "Collects errors via buffered channel", + "All 5 clients succeed independently", + "Test passes with `go test -v -run TestE2E_MultipleConcurrentClients ./go/tunnels/`" + ], + "priority": 14, + "passes": true, + "notes": "" + }, + { + "id": "US-014", + "title": "Test concurrent port operations with race detector", + "description": "As a developer, I want to verify thread safety of concurrent AddPort/RemovePort operations.", + "acceptanceCriteria": [ + "`TestE2E_ConcurrentPortOperations` test function", + "Creates Host, connects", + "Accesses SSH session via h.mu.Lock(); ssh := h.ssh; h.mu.Unlock()", + "Launches 10 goroutines: 0-4 add ports 9000-9004, 5-9 add then remove ports 9005-9009", + "sync.WaitGroup waits for all", + "Verifies ports 9000-9004 exist via ssh.HasPort()", + "Verifies ports 9005-9009 do NOT exist", + "Verifies len(ssh.Ports()) == 5", + "Must pass with `go test -race`", + "Test passes with `go test -v -race -run TestE2E_ConcurrentPortOperations ./go/tunnels/`" + ], + "priority": 15, + "passes": true, + "notes": "This test accesses internal h.mu and h.ssh fields. Must run with -race flag to validate mutex correctness." + }, + { + "id": "US-015", + "title": "Test host public key verification", + "description": "As a developer, I want to verify the client can capture and verify the host's SSH public key matches what was registered.", + "acceptanceCriteria": [ + "`TestE2E_HostPublicKeyVerification` test function", + "Creates Host, connects", + "Gets expectedKey := host.HostPublicKeyBase64()", + "connectClientWithKeyCapture(t, relay) returns sshConn, globalReqs, receivedKey", + "Encodes received key: base64.StdEncoding.EncodeToString(receivedKey.Marshal())", + "Asserts receivedKeyBase64 == expectedKey", + "Test passes with `go test -v -run TestE2E_HostPublicKeyVerification ./go/tunnels/`" + ], + "priority": 16, + "passes": true, + "notes": "" + }, + { + "id": "US-016", + "title": "Test IPv4 and IPv6 port forwarding", + "description": "As a developer, I want to verify port forwarding works over both IPv4 and IPv6.", + "acceptanceCriteria": [ + "`TestE2E_IPv4AndIPv6` test function", + "Starts IPv4 echo server on 127.0.0.1:0", + "Starts IPv6 echo server on [::1]:0 — t.Skip(\"IPv6 not available\") if listen fails", + "Creates Host, connects, adds both ports", + "Connects client, starts drainAndReplyRequests", + "Opens forwarded-tcpip to IPv4 port, sends \"ipv4-test\", verifies echo", + "Opens forwarded-tcpip to IPv6 port, sends \"ipv6-test\", verifies echo", + "Test passes with `go test -v -run TestE2E_IPv4AndIPv6 ./go/tunnels/`" + ], + "priority": 17, + "passes": true, + "notes": "" + }, + { + "id": "US-017", + "title": "Test connection refused graceful handling", + "description": "As a developer, I want to verify that when a port is forwarded but no local listener exists, the host handles it gracefully.", + "acceptanceCriteria": [ + "`TestE2E_ConnectionRefused` test function", + "Creates Host, connects", + "host.AddPort(ctx, &TunnelPort{PortNumber: 19999}) — no listener on this port", + "Connects client, waits for tcpip-forward", + "Opens forwarded-tcpip channel to port 19999", + "Either channel open fails cleanly, or channel opens but read returns io.EOF or error", + "No panic", + "Host remains functional", + "Test passes with `go test -v -run TestE2E_ConnectionRefused ./go/tunnels/`" + ], + "priority": 18, + "passes": true, + "notes": "" + }, + { + "id": "US-018", + "title": "Test client disconnect mid-transfer", + "description": "As a developer, I want to verify client disconnect during data transfer doesn't crash the host.", + "acceptanceCriteria": [ + "`TestE2E_ClientDisconnectMidTransfer` test function", + "Starts echo server, creates Host, connects, adds port", + "Connects client 1, opens forwarded-tcpip channel", + "Writes \"partial data\" to channel", + "Immediately closes client 1 SSH connection (abrupt disconnect)", + "No panic after 200ms sleep", + "Connects client 2, opens forwarded-tcpip channel", + "Writes \"full message\", verifies echo response", + "Test passes with `go test -v -run TestE2E_ClientDisconnectMidTransfer ./go/tunnels/`" + ], + "priority": 19, + "passes": true, + "notes": "" + }, + { + "id": "US-019", + "title": "Test Wait blocks until disconnect", + "description": "As a developer, I want to verify host.Wait() blocks until the relay drops, then returns promptly.", + "acceptanceCriteria": [ + "`TestE2E_WaitBlocksUntilDisconnect` test function", + "Creates Host, connects", + "Launches host.Wait() in goroutine, sends result to channel", + "After 200ms, verifies channel is empty (Wait is still blocking)", + "Closes relay server", + "Waits for result from channel with 5s timeout", + "Verifies Wait returned (non-nil error expected)", + "Test passes with `go test -v -run TestE2E_WaitBlocksUntilDisconnect ./go/tunnels/`" + ], + "priority": 20, + "passes": true, + "notes": "" + }, + { + "id": "US-020", + "title": "Test reconnection after relay drop", + "description": "As a developer, I want to verify that with EnableReconnect=true, the host automatically reconnects when the relay drops.", + "acceptanceCriteria": [ + "`TestE2E_Reconnection` test function", + "Creates two RelayHostServer instances", + "Creates e2eMockAPI initially pointing to relay 1", + "Sets host.EnableReconnect = true", + "Tracks status transitions via callback", + "Connects to relay 1, verifies connected", + "Creates relay 2, updates api.relayURI to relay 2 URL", + "Closes relay 1 to simulate disconnect", + "relay2.WaitForConnection(30s) — reconnection succeeds", + "Status transitions include: Connected → Disconnected → Connecting → Connected", + "Close host", + "Test passes with `go test -v -run TestE2E_Reconnection -timeout 60s ./go/tunnels/`" + ], + "priority": 21, + "passes": true, + "notes": "reconnect() uses 100ms base delay with 2x exponential backoff (max 12.8s). Test uses 30s timeout — generous but necessary." + }, + { + "id": "US-021", + "title": "Test token refresh during reconnection", + "description": "As a developer, I want to verify RefreshTunnelAccessTokenFunc is invoked when the management API returns 401 during reconnection.", + "acceptanceCriteria": [ + "`TestE2E_TokenRefresh` test function", + "Sets host.EnableReconnect = true", + "Sets host.RefreshTunnelAccessTokenFunc to callback that increments atomic counter and returns \"refreshed-token\", nil", + "Connects, verifies connected", + "Sets api.unauthorizedOnce = 1 (next UpdateTunnelEndpoint returns 401)", + "Creates relay 2, updates api.relayURI", + "Closes relay 1 to trigger reconnection", + "Reconnection: first attempt gets 401, calls token refresh, retries, succeeds", + "Waits for relay 2 connection", + "Asserts tokenRefreshCalls >= 1", + "Test passes with `go test -v -run TestE2E_TokenRefresh -timeout 60s ./go/tunnels/`" + ], + "priority": 22, + "passes": true, + "notes": "Depends on US-020 reconnection infrastructure. On 401, token refresh is attempted and retry happens immediately (no additional delay)." + } + ] +} diff --git a/progress.txt b/progress.txt new file mode 100644 index 00000000..89d7b831 --- /dev/null +++ b/progress.txt @@ -0,0 +1,350 @@ +## Codebase Patterns +- Use `tunnelstest.NewRelayHostServer()` with `WithHostAccessToken("Tunnel test-token")` for WebSocket-level host relay mocking +- Use `tcpConnPair(t)` for TCP loopback pairs in SSH tests — NEVER use `net.Pipe()` (deadlocks SSH handshakes) +- Mock management API pattern: `httptest.NewServer` routing on method + URL path segments (`/endpoints/`, `/ports/`, `/tunnels/`) +- Two-level SSH: Host connects as SSH client to relay; each `client-ssh-session-stream` channel gets a nested SSH server +- AddPort/RemovePort send SSH global requests with `wantReply=true` — must drain+reply to avoid deadlock +- Test assertions use standard `testing.T` with `t.Fatal`/`t.Fatalf` — no assertion libraries +- Host tests are in package `tunnels` (same package) — can access internal fields like `h.mu`, `h.ssh` +- `ConnectionStatus` enum: None(0), Connecting(1), Connected(2), Disconnected(3) +- Port forwarding test flow: start echo server → setup host+relay → AddPort → connect client → wait for tcpip-forward → open channel → write/read/verify + +--- + +## 2026-02-27 - US-000 +- Verified Go development environment +- Go 1.26.0 installed at /opt/homebrew/bin/go (darwin/arm64) +- Dependencies verified: gorilla/websocket v1.4.2, golang.org/x/crypto v0.23.0 +- `go vet ./go/tunnels/...` passes clean +- `go test -short -v ./go/tunnels/ssh/...` passes — all 42 tests pass (3.69s) +- `go test -race -short -run TestChannelConn ./go/tunnels/ssh/` passes +- No files changed — environment verification only +- **Learnings for future iterations:** + - Go 1.26.0 is the runtime version; go.mod specifies go 1.17 for compat + - All existing SSH and host tests pass, providing a solid foundation for E2E tests + - The race detector works correctly with the test suite +--- + +## 2026-02-27 - US-001 +- Created shared E2E test infrastructure in `go/tunnels/host_e2e_test.go` +- Implemented `e2eMockAPI` struct with atomic counters, dynamic relay URI, remote ports, and 401 simulation +- Implemented HTTP handler routes for PUT/DELETE endpoints, GET tunnels, PUT/DELETE ports +- Implemented `connectClientViaRelay(t, relay)` — SimulateClientArrival + nested SSH handshake +- Implemented `connectClientWithKeyCapture(t, relay)` — same with HostKeyCallback capture +- Implemented `startEchoServerE2E(t)` — TCP echo server with t.Cleanup +- Implemented `openForwardedChannel(t, sshConn, port)` — forwarded-tcpip channel helper +- Implemented `drainAndReplyRequests(globalReqs, collected)` — background goroutine that replies to SSH global requests +- Files changed: go/tunnels/host_e2e_test.go (new) +- `go vet` and `go build` pass clean +- **Learnings for future iterations:** + - E2E helpers use `t.Cleanup` for automatic resource cleanup instead of defer in callers + - `connectClientViaRelay` and `connectClientWithKeyCapture` register SSH conn cleanup via t.Cleanup + - `e2eMockAPI.relayURI` and `remotePorts` use `atomic.Value` for thread-safe dynamic updates + - `drainAndReplyRequests` drops requests if collected channel buffer is full (non-blocking send) + - The `testForwardedTCPIPData` struct from host_test.go is reused (same package) +--- + +## 2026-02-27 - US-002 +- Implemented `TestE2E_HostLifecycle` in `go/tunnels/host_e2e_test.go` +- Verifies full host lifecycle: construction → connection status None → connect → Connected → close → Disconnected → idempotent close +- Asserts HostPublicKeyBase64() is non-empty and valid base64 +- Asserts deleteEndpointCalls == 1 after first close, still 1 after idempotent close +- Files changed: go/tunnels/host_e2e_test.go (added imports + test function) +- `go vet`, `go test -short ./go/tunnels/...` all pass +- **Learnings for future iterations:** + - E2E test pattern: create relay → create e2eMockAPI(t, relay.URL()) → NewHost(logger, api.manager) → Connect → assert → Close + - Tunnel struct needs AccessTokens map with TunnelAccessScopeHost key matching relay token (without "Tunnel " prefix — host.Connect prepends it) + - t.Cleanup(func() { relay.Close() }) is preferred over defer for relay cleanup in E2E tests +--- + +## 2026-02-27 - US-003 +- Implemented `TestE2E_ErrorHandling` with 11 table-driven subtests in `go/tunnels/host_e2e_test.go` +- Subtests: ErrNoManager, ErrAlreadyConnected, ErrNotConnected_Close/Wait/AddPort/RemovePort/RefreshPorts, ErrPortAlreadyAdded, ErrTooManyConnections, ErrNoHostRelayURI, 409Conflict_Tolerated +- Added `portConflictOnce int32` field to `e2eMockAPI` for one-shot 409 simulation on port creation +- Added imports: `errors`, `tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh"` +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - All error sentinels (ErrNoManager, ErrAlreadyConnected, etc.) are returned directly (not wrapped), so errors.Is() works without unwrapping + - For ErrTooManyConnections test: set `host.disconnectReason = tunnelssh.SshDisconnectReasonTooManyConnections` directly (same package access) + - For ErrNoHostRelayURI test: use `newE2EMockAPI(t, "")` to return an empty HostRelayURI from the mock + - AddPort tolerates 409 Conflict via `errors.As(err, &tunnelErr)` + StatusCode check — the port is still added to the SSH session + - ErrNotConnected subtests don't need a relay — just a host created with `NewHost` but never connected +--- + +## 2026-02-27 - US-004 +- Implemented `TestE2E_ConnectionStatusCallbacks` in `go/tunnels/host_e2e_test.go` +- Registers `ConnectionStatusChanged` callback that appends `ConnectionStatus` values to a mutex-guarded slice +- Asserts initial status is `ConnectionStatusNone` +- After Connect: verifies Connecting, Connected transitions +- After Close: verifies Disconnected transition +- Asserts exactly 3 transitions with no duplicates or extras +- Files changed: go/tunnels/host_e2e_test.go (added `sync` import + test function) +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - `setConnectionStatus` fires the callback with (prev, curr) only when prev != curr — no duplicate transitions + - The callback is read under mutex but invoked outside the lock, so use a separate sync.Mutex to guard the transitions slice + - `ConnectionStatusChanged` must be set before `Connect()` to capture the Connecting transition +--- + +## 2026-02-27 - US-005 +- Implemented `TestE2E_PortForwardingDataFlow` in `go/tunnels/host_e2e_test.go` +- Verifies the complete data path: echo server → host → relay → client → forwarded-tcpip channel → echo response +- Starts TCP echo server, creates relay+mock API+host, connects, adds port +- Connects client via relay, waits for tcpip-forward global request, replies true +- Opens forwarded-tcpip channel, writes "hello e2e tunnel", verifies echo matches +- Files changed: go/tunnels/host_e2e_test.go (added test function) +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - Port forwarding data flow test follows the same pattern as TestHostAndClientIntegration in host_test.go but uses relay infrastructure + - The client must explicitly wait for and reply to the tcpip-forward global request before opening a forwarded-tcpip channel + - openForwardedChannel helper handles the forwarded-tcpip channel setup including discarding channel requests + - The bidirectional copy in handleForwardedTCPIP correctly propagates CloseWrite() for clean echo server interaction +--- + +## 2026-02-27 - US-006 +- Implemented `TestE2E_DirectTcpipAndForwardedTcpip` in `go/tunnels/host_e2e_test.go` +- Tests three scenarios: forwarded-tcpip echo, direct-tcpip echo, and rejection of unregistered port +- Opens forwarded-tcpip channel, sends "forwarded-test", verifies echo response +- Opens direct-tcpip channel to same port, sends "direct-test", verifies echo response +- Opens direct-tcpip to unregistered port (echoPort + 1000), asserts *ssh.OpenChannelError with reason ssh.Prohibited +- Uses drainAndReplyRequests to handle wantReply=true tcpip-forward from AddPort +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - Both forwarded-tcpip and direct-tcpip use the same `forwardedTCPIPData` struct and same handler path in `handleClientChannels` + - To test direct-tcpip, use `sshConn.OpenChannel("direct-tcpip", data)` with the same marshaled data as forwarded-tcpip + - Channel rejection returns `*ssh.OpenChannelError` with `.Reason` field matching SSH rejection reason codes (e.g., `ssh.Prohibited`) + - The existing `openForwardedChannel` helper only supports "forwarded-tcpip"; for "direct-tcpip", construct the channel open inline +--- + +## 2026-02-27 - US-007 +- Implemented `TestE2E_MultiplePorts` in `go/tunnels/host_e2e_test.go` +- Starts 3 echo servers on 3 different ports (portA, portB, portC) +- Adds all 3 ports via host.AddPort, then connects client and receives 3 tcpip-forward requests +- For each port: opens forwarded-tcpip channel, sends unique message ("port-A", "port-B", "port-C"), verifies echo +- Confirms no cross-contamination between ports +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - When adding multiple ports before connecting a client, all tcpip-forward requests arrive when the client connects — can receive them in a simple loop + - Each port gets its own independent forwarded-tcpip channel; no shared state between port channels + - The pattern of add-ports → connect-client → wait-for-N-forwards → test-each-port is clean and reusable for multi-port scenarios +--- + +## 2026-02-27 - US-008 +- Implemented `TestE2E_DynamicPortManagement` in `go/tunnels/host_e2e_test.go` +- Tests adding/removing ports while a client is already connected with real-time SSH notifications +- Starts echo server A, connects host, adds port A, connects client with drainAndReplyRequests +- Waits for tcpip-forward for port A, verifies data flows through port A +- Starts echo server B, dynamically adds port B, waits for tcpip-forward for port B, verifies data flow +- Removes port A via host.RemovePort, waits for cancel-tcpip-forward in collected channel +- Attempts to open forwarded-tcpip to removed port A — asserts *ssh.OpenChannelError with ssh.Prohibited +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - RemovePort sends "cancel-tcpip-forward" via cancelForwardPortToClient to all connected clients (wantReply=true — must drain+reply) + - After RemovePort, the port is removed from the SSH session's port list, so forwarded-tcpip channels to that port are rejected with ssh.Prohibited + - Dynamic port add/remove while client is connected works seamlessly — client receives tcpip-forward and cancel-tcpip-forward in real time + - The collected channel from drainAndReplyRequests receives both tcpip-forward and cancel-tcpip-forward requests in order +--- + +## 2026-02-27 - US-009 +- Implemented `TestE2E_PortDuplicateHandling` in `go/tunnels/host_e2e_test.go` +- Creates Host, connects, adds port 8080, then tries to add 8080 again — asserts ErrPortAlreadyAdded +- Accesses SSH session via `host.mu.Lock(); sshSession := host.ssh; host.mu.Unlock()` +- Verifies `sshSession.Ports()` contains 8080 exactly once (no duplicates) +- Verifies `api.createPortCalls == 1` (duplicate was rejected before reaching the API) +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - ErrPortAlreadyAdded is returned before calling the management API — so createPortCalls stays at 1 + - Accessing internal fields (host.mu, host.ssh) from tests in the same package is straightforward and used for SSH session inspection + - sshSession.Ports() returns a copy of the port list — safe to iterate without holding locks +--- + +## 2026-02-27 - US-010 +- Implemented `TestE2E_RefreshPorts` in `go/tunnels/host_e2e_test.go` +- Creates Host, connects, then manually adds port 5000 to the SSH session +- Configures mock API to return remote ports 3000 and 4000 +- Calls host.RefreshPorts(ctx) and asserts: 3000 present, 4000 present, 5000 removed +- Calls RefreshPorts again to verify idempotency — all assertions still hold +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass +- **Learnings for future iterations:** + - RefreshPorts calls sshSession.AddPort/RemovePort directly (not Host.AddPort/RemovePort), so no management API calls for individual ports + - When no clients are connected, sshSession.AddPort/RemovePort only update the internal port list — no SSH global requests sent, no deadlock risk + - The mock API's GET tunnels handler returns `api.remotePorts` as the tunnel's Ports field, which RefreshPorts reads via `h.manager.GetTunnel` with `IncludePorts: true` + - RefreshPorts also updates `h.tunnel.Ports` with the refreshed ports under lock +--- + +## 2026-02-27 - US-011 +- Implemented `TestE2E_LargeDataTransfer` in `go/tunnels/host_e2e_test.go` +- Generates 1MB payload with deterministic content: payload[i] = byte(i % 256) +- Computes SHA256 hash before sending, compares with hash of echoed response +- Uses goroutine for write + CloseWrite(), main thread reads via io.Copy +- Asserts received length == 1048576 and SHA256 hashes match +- Added `bytes` and `crypto/sha256` imports +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all E2E tests pass (10/10) +- **Learnings for future iterations:** + - The bidirectional copy in handleForwardedTCPIP properly propagates CloseWrite() — this is essential for large data echo tests + - Use ch.CloseWrite() after writing (not ch.Close()) to signal EOF to the echo server while keeping the read side open + - io.Copy on the read side blocks until the echo server echoes all data and the write side is closed + - 1MB transfers complete in ~20ms through the relay+SSH tunnel stack — very fast +--- + +## 2026-02-27 - US-012 +- Implemented `TestE2E_BidirectionalStreaming` in `go/tunnels/host_e2e_test.go` +- Sends 10 sequential messages of increasing size (100, 200, ..., 1000 bytes) through a forwarded-tcpip channel +- Each message uses deterministic content: `bytes.Repeat([]byte{byte(i)}, size)` for easy verification +- Reads each echo response via `io.ReadFull` and verifies with `bytes.Equal` +- All 10 messages echo correctly with no ordering issues +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 12 E2E tests pass (0.32s total) +- **Learnings for future iterations:** + - Sequential request-response over a single SSH channel works reliably — no need for separate channels per message + - The echo server pattern (io.Copy) handles sequential writes correctly — each read gets exactly the bytes written + - `io.ReadFull` is the right choice for fixed-size reads in sequential streaming scenarios +--- + +## 2026-02-27 - US-013 +- Implemented `TestE2E_MultipleConcurrentClients` in `go/tunnels/host_e2e_test.go` +- Spawns 5 goroutines, each connecting via `connectClientViaRelay`, draining+replying global requests, waiting for tcpip-forward, opening forwarded-tcpip channel, and verifying echo with unique message `fmt.Sprintf("client-%d", i)` +- Uses `sync.WaitGroup` to wait for all goroutines and buffered error channel to collect failures +- All 5 clients succeed independently with no cross-contamination +- Added `"fmt"` import for goroutine error formatting +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 13 E2E tests pass (0.31s total) +- **Learnings for future iterations:** + - `SimulateClientArrival()` can be called multiple times on the same relay — each call opens a new `client-ssh-session-stream` channel + - `connectClientViaRelay` uses `t.Fatalf` which is not goroutine-safe, but the test helper calls happen early and sequentially in practice; for strict safety, error reporting uses the `errs` channel instead of t.Fatalf in the goroutine body + - Each concurrent client independently receives its own tcpip-forward global request when connecting — the host sends port forwards to each new client session +--- + +## 2026-02-27 - US-014 +- Implemented `TestE2E_ConcurrentPortOperations` in `go/tunnels/host_e2e_test.go` +- Launches 10 goroutines: goroutines 0-4 add ports 9000-9004 (persist), goroutines 5-9 add then remove ports 9005-9009 +- Uses sync.WaitGroup to wait for all goroutines to complete +- Verifies ports 9000-9004 exist via sshSession.HasPort() +- Verifies ports 9005-9009 do NOT exist (were added then removed) +- Verifies len(sshSession.Ports()) == 5 exactly +- Passes with `go test -v -race` — validates mutex correctness under concurrent access +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 14 E2E tests pass with -race flag (1.53s total) +- **Learnings for future iterations:** + - HostSSHSession.AddPort/RemovePort use portsMu (RWMutex) and are safe for concurrent use — no race conditions + - For concurrent port tests, use sshSession.AddPort/RemovePort directly (not Host.AddPort/RemovePort) to avoid needing a connected client to drain SSH global requests + - The race detector confirms the locking strategy in HostSSHSession is correct for concurrent port operations +--- + +## 2026-02-27 - US-015 +- Implemented `TestE2E_HostPublicKeyVerification` in `go/tunnels/host_e2e_test.go` +- Creates Host, connects, gets `HostPublicKeyBase64()` as the expected key +- Uses `connectClientWithKeyCapture(t, relay)` to receive the host's SSH public key during handshake +- Encodes received key via `base64.StdEncoding.EncodeToString(receivedKey.Marshal())` and asserts it matches +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 15 E2E tests pass (0.40s total) +- **Learnings for future iterations:** + - `connectClientWithKeyCapture` uses a custom `HostKeyCallback` that captures the key during the nested SSH handshake — the key is available immediately after `ssh.NewClientConn` returns + - `HostPublicKeyBase64()` returns the base64-encoded marshaled form of the host's SSH public key, which matches `base64.StdEncoding.EncodeToString(receivedKey.Marshal())` + - This is a lightweight test — no port forwarding or data flow needed, just the SSH handshake +--- + +## 2026-02-27 - US-016 +- Implemented `TestE2E_IPv4AndIPv6` in `go/tunnels/host_e2e_test.go` +- Starts IPv4 echo server on `127.0.0.1:0` via existing `startEchoServerE2E(t)` +- Starts IPv6 echo server on `[::1]:0` — skips test with `t.Skip("IPv6 not available")` if listen fails +- Creates host, connects, adds both ports +- Connects client with `drainAndReplyRequests` to handle both tcpip-forward requests +- Opens forwarded-tcpip to IPv4 port, sends "ipv4-test", verifies echo +- Opens forwarded-tcpip to IPv6 port, sends "ipv6-test", verifies echo +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 16 E2E tests pass (0.33s total) +- **Learnings for future iterations:** + - `handleForwardedTCPIP` in host_session.go tries IPv4 first (`tcp4` dial to `127.0.0.1`), then falls back to IPv6 (`tcp6` dial to `[::1]`) — this is how IPv6 port forwarding works + - To test IPv6 specifically, start the echo server on `[::1]:0` only (using `net.Listen("tcp6", "[::1]:0")`) — the IPv4 dial will fail and the fallback to IPv6 kicks in + - The IPv6 echo server uses inline setup (not `startEchoServerE2E`) since the helper is hardcoded to `127.0.0.1:0` + - `drainAndReplyRequests` is needed because AddPort sends wantReply=true tcpip-forward for each port added +--- + +## 2026-02-27 - US-017 +- Implemented `TestE2E_ConnectionRefused` in `go/tunnels/host_e2e_test.go` +- Adds port 19999 (no listener), connects client, opens forwarded-tcpip channel +- Host accepts channel but local dial fails → channel is closed → client gets EOF on read +- Verifies no panic and host remains functional by adding an echo port afterward and testing data flow +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 17 E2E tests pass (0.33s total) +- **Learnings for future iterations:** + - When `handleForwardedTCPIP` fails to dial locally, it logs the error and calls `channel.Close()` — the client sees EOF on read + - When testing host functionality after a failure, must start `drainAndReplyRequests` BEFORE calling `AddPort` — otherwise wantReply=true deadlocks + - The host gracefully handles connection-refused with no crash/panic — it logs the error and continues accepting new channels +--- + +## 2026-02-27 - US-018 +- Implemented `TestE2E_ClientDisconnectMidTransfer` in `go/tunnels/host_e2e_test.go` +- Starts echo server, creates relay+mock API+host, connects, adds echo port +- Connects client 1 via relay with drainAndReplyRequests, waits for tcpip-forward +- Opens forwarded-tcpip channel, writes "partial data", then immediately closes client 1 SSH connection (abrupt disconnect) +- Sleeps 200ms to confirm no panic +- Connects client 2 via relay with drainAndReplyRequests, waits for tcpip-forward +- Opens forwarded-tcpip channel, writes "full message", verifies echo response — host is still functional +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 18 E2E tests pass (0.55s total) +- **Learnings for future iterations:** + - Abrupt client disconnect (sshConn.Close()) during an active data transfer does not crash the host — the host handles broken pipe/EOF errors gracefully in handleForwardedTCPIP + - After a client disconnects, new clients can connect and forward data through the same ports without issues + - The 200ms sleep is sufficient to confirm no panic — the host processes the disconnect asynchronously + - Use drainAndReplyRequests for both client 1 and client 2 — AddPort was already called before client 1 connected, so the tcpip-forward request arrives when each client connects +--- + +## 2026-02-27 - US-019 +- Implemented `TestE2E_WaitBlocksUntilDisconnect` in `go/tunnels/host_e2e_test.go` +- Creates Host, connects, launches `host.Wait()` in goroutine, verifies it blocks +- After 200ms, confirms Wait has not returned (channel is empty) +- Closes relay server to simulate relay drop +- Waits for Wait to return with 5s timeout, asserts non-nil error +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 19 E2E tests pass (0.78s total) +- **Learnings for future iterations:** + - `Host.Wait()` blocks on `sshSession.Wait()` which blocks on the SSH `conn.Wait()` — it unblocks when the underlying connection breaks + - Closing the relay server (`relay.Close()`) closes both the SSH connection and the HTTP server, causing `Wait()` to return with a non-nil error + - No `defer host.Close()` needed in this test since `Wait()` returns after relay close, and we want to observe the natural disconnect behavior + - The 200ms sleep for the "still blocking" check is sufficient — the relay drop causes Wait to return in <100ms +--- + +## 2026-02-27 - US-020 +- Implemented `TestE2E_Reconnection` in `go/tunnels/host_e2e_test.go` +- Creates two `RelayHostServer` instances; connects to relay1, then simulates disconnect by closing relay1 +- Sets `EnableReconnect = true` and tracks status transitions via `ConnectionStatusChanged` callback +- After relay1 closes, `Wait()` detects disconnect and calls `reconnect()` with exponential backoff +- `reconnect()` calls `connectOnce()` which reads updated `api.relayURI` (now pointing to relay2) and connects +- `relay2.WaitForConnection(30s)` confirms reconnection succeeded +- Polls for `ConnectionStatusConnected` after relay2 connection (relay connects before host finishes `connectOnce`) +- Verifies status transitions include: Connected → Disconnected → Connecting → Connected +- Closes host and waits for `Wait()` goroutine to return +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 20 E2E tests pass with -race flag (2.2s total) +- **Learnings for future iterations:** + - `relay.WaitForConnection()` returns when the relay-side SSH handshake completes, but the host's `connectOnce` may still be in progress — poll for `ConnectionStatusConnected` instead of asserting immediately + - `reconnect()` uses 100ms base delay with 2x exponential backoff (max 12.8s); first reconnection attempt happens ~100ms after disconnect + - When `host.Close()` is called while `Wait()` is in its reconnection loop, `cancel()` cancels `h.ctx`, causing `reconnect()` to return `context.Canceled` on the next delay check + - The `e2eMockAPI.relayURI` atomic.Value allows seamlessly switching relay URLs between reconnection attempts + - `RelayHostServer.handleConnection` properly handles reconnect scenarios by resetting the `connected` channel +--- + +## 2026-02-27 - US-021 +- Implemented `TestE2E_TokenRefresh` in `go/tunnels/host_e2e_test.go` +- Sets `EnableReconnect = true` and `RefreshTunnelAccessTokenFunc` callback that increments atomic counter and returns "refreshed-token" +- Connects to relay 1, sets `api.unauthorizedOnce = 1` so next `UpdateTunnelEndpoint` returns 401 +- Creates relay 2 with `WithHostAccessToken("Tunnel refreshed-token")` to accept the refreshed token +- Closes relay 1 to trigger reconnection; first attempt gets 401, calls token refresh, retries with new token, succeeds +- Asserts `tokenRefreshCalls >= 1` to confirm the callback was invoked +- Files changed: go/tunnels/host_e2e_test.go +- `go vet` and all 21 E2E tests pass with -race flag (2.48s total) +- **Learnings for future iterations:** + - `reconnect()` catches 401 via `errors.As(err, &tunnelErr)` with `StatusCode == 401`, then calls `refreshAccessToken()` which invokes the callback + - After token refresh, `reconnect()` retries immediately with `continue` (same delay, no additional backoff) — the 401 path has two reconnection log lines at the same delay + - The relay 2 must accept the refreshed token ("Tunnel refreshed-token"), not the original token — use `WithHostAccessToken("Tunnel refreshed-token")` when creating relay 2 + - The `unauthorizedOnce` atomic flag is consumed via `CompareAndSwapInt32` — only the first `UpdateTunnelEndpoint` call after setting it returns 401 +---