Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions go/tunnels/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ import (
tunnelstest "github.com/microsoft/dev-tunnels/go/tunnels/test"
)

func verboseLogger(t *testing.T) *log.Logger {
loggerOutput := io.Discard
if testing.Verbose() {
loggerOutput = os.Stdout
}
return log.New(loggerOutput, "", log.LstdFlags)
}

func TestSuccessfulConnect(t *testing.T) {
accessToken := "tunnel access-token"
relayServer, err := tunnelstest.NewRelayServer(
Expand All @@ -45,7 +53,7 @@ func TestSuccessfulConnect(t *testing.T) {
},
}

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
done := make(chan error)
go func() {
c, err := NewClient(logger, &tunnel, true)
Expand Down Expand Up @@ -95,7 +103,7 @@ func TestReturnsErrWithInvalidAccessToken(t *testing.T) {
},
}

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
c, _ := NewClient(logger, &tunnel, true)
err = c.Connect(ctx, "")
if err == nil {
Expand All @@ -104,15 +112,15 @@ func TestReturnsErrWithInvalidAccessToken(t *testing.T) {
}

func TestReturnsErrWhenTunnelIsNil(t *testing.T) {
logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
_, err := NewClient(logger, nil, true)
if err == nil {
t.Error("expected error, got nil")
}
}

func TestReturnsErrWhenEndpointsAreNil(t *testing.T) {
logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
tunnel := Tunnel{}
_, err := NewClient(logger, &tunnel, true)
if err == nil {
Expand All @@ -129,7 +137,7 @@ func TestReturnsErrWhenTunnelEndpointsDontMatchHostID(t *testing.T) {
},
}

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
c, _ := NewClient(logger, &tunnel, true)
err := c.Connect(ctx, "host2")
if err == nil {
Expand All @@ -149,7 +157,7 @@ func TestReturnsErrWhenEndpointGroupsContainMultipleHosts(t *testing.T) {
},
}

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
c, _ := NewClient(logger, &tunnel, true)
err := c.Connect(ctx, "host1")
if err == nil {
Expand All @@ -169,7 +177,7 @@ func TestReturnsErrWhenThereAreMoreThanOneEndpoints(t *testing.T) {
},
}

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
c, _ := NewClient(logger, &tunnel, true)
err := c.Connect(ctx, "")
if err == nil {
Expand Down Expand Up @@ -216,7 +224,7 @@ func TestPortForwarding(t *testing.T) {
ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
defer cancel()

logger := log.New(os.Stdout, "", log.LstdFlags)
logger := verboseLogger(t)
done := make(chan error)
go func() {
c, err := NewClient(logger, &tunnel, true)
Expand Down Expand Up @@ -295,3 +303,4 @@ func TestPortForwarding(t *testing.T) {
}
}
}

68 changes: 57 additions & 11 deletions go/tunnels/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -71,6 +72,15 @@ type UserAgent struct {
Version string
}

type requestError struct {
statusCode int
message string
}

func (e *requestError) Error() string {
return e.message
}

// Manager is used to interact with the Visual Studio Tunnel Service APIs.
type Manager struct {
tokenProvider tokenProviderfn
Expand Down Expand Up @@ -182,8 +192,10 @@ func (m *Manager) CreateTunnel(ctx context.Context, tunnel *Tunnel, options *Tun
if tunnel == nil {
return nil, fmt.Errorf("tunnel must be provided")
}
idGenerated := false
if tunnel.TunnelID == "" {
tunnel.TunnelID = generateTunnelId()
idGenerated = true
}

if options == nil {
Expand All @@ -194,10 +206,6 @@ func (m *Manager) CreateTunnel(ctx context.Context, tunnel *Tunnel, options *Tun
}
options.AdditionalHeaders["If-Not-Match"] = "*"

url, err := m.buildTunnelSpecificUri(tunnel, "", options, "", true)
if err != nil {
return nil, fmt.Errorf("error creating request url: %w", err)
}
convertedTunnel, err := tunnel.requestObject()
convertedTunnel.TunnelID = tunnel.TunnelID
if err != nil {
Expand All @@ -206,11 +214,23 @@ func (m *Manager) CreateTunnel(ctx context.Context, tunnel *Tunnel, options *Tun
var response []byte

for i := 0; i < createNameRetries; i++ {
response, err = m.sendTunnelRequest(ctx, tunnel, options, http.MethodPut, url, convertedTunnel, nil, manageAccessTokenScope, false)
url, err := m.buildTunnelSpecificUri(tunnel, "", options, "", true)
if err != nil {
convertedTunnel.TunnelID = generateTunnelId()
tunnel.TunnelID = convertedTunnel.TunnelID
return nil, fmt.Errorf("error creating request url: %w", err)
}
response, err = m.sendTunnelRequest(ctx, tunnel, options, http.MethodPut, url, convertedTunnel, nil, manageAccessTokenScope, false)
if err == nil {
break
}
if !idGenerated {
break
}
var requestErr *requestError
if !errors.As(err, &requestErr) || requestErr.statusCode != http.StatusConflict {
break
}
convertedTunnel.TunnelID = generateTunnelId()
tunnel.TunnelID = convertedTunnel.TunnelID
}
if err != nil {
return nil, fmt.Errorf("error sending create tunnel request: %w", err)
Expand All @@ -232,6 +252,26 @@ func (m *Manager) UpdateTunnel(ctx context.Context, tunnel *Tunnel, updateFields
return nil, fmt.Errorf("tunnel must be provided")
}

if len(updateFields) > 0 {
// TunnelID and ClusterID are not updatable but must be supplied in the update body.
needTunnelID := true
needClusterID := true
for _, field := range updateFields {
if field == "TunnelID" {
needTunnelID = false
}
if field == "ClusterID" {
needClusterID = false
}
}
if needTunnelID {
updateFields = append(updateFields, "TunnelID")
}
if needClusterID {
updateFields = append(updateFields, "ClusterID")
}
}

if options == nil {
options = &TunnelRequestOptions{}
}
Expand Down Expand Up @@ -727,11 +767,17 @@ func (m *Manager) sendRequest(
if result.StatusCode > 300 {
errorMessage, err := m.readProblemDetails(result)
if err == nil && errorMessage != nil {
return nil, fmt.Errorf("unsuccessful request, response: %d %s\n\t%s",
result.StatusCode, http.StatusText(result.StatusCode), *errorMessage)
return nil, &requestError{
statusCode: result.StatusCode,
message: fmt.Sprintf("unsuccessful request, response: %d %s\n\t%s",
result.StatusCode, http.StatusText(result.StatusCode), *errorMessage),
}
} else {
return nil, fmt.Errorf("unsuccessful request, response: %d: %s",
result.StatusCode, http.StatusText(result.StatusCode))
return nil, &requestError{
statusCode: result.StatusCode,
message: fmt.Sprintf("unsuccessful request, response: %d: %s\n\t%s",
result.StatusCode, http.StatusText(result.StatusCode), err.Error()),
}
}
}

Expand Down
Loading
Loading