Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions bundle/run/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ func (r *jobRunner) logFailedTasks(ctx context.Context, runId int64) {
log.Errorf(ctx, "task %s failed. Unable to fetch error trace: %s", red(task.TaskKey), err)
continue
}
if progressLogger, ok := cmdio.FromContext(ctx); ok {
progressLogger.Log(progress.NewTaskErrorEvent(task.TaskKey, taskInfo.Error, taskInfo.ErrorTrace))
}
cmdio.Log(ctx, progress.NewTaskErrorEvent(task.TaskKey, taskInfo.Error, taskInfo.ErrorTrace))
log.Errorf(ctx, "Task %s failed!\nError:\n%s\nTrace:\n%s",
red(task.TaskKey), taskInfo.Error, taskInfo.ErrorTrace)
} else {
Expand All @@ -89,9 +87,8 @@ func (r *jobRunner) logFailedTasks(ctx context.Context, runId int64) {
// jobRunMonitor tracks state for a single job run and provides callbacks
// for monitoring progress.
type jobRunMonitor struct {
ctx context.Context
prevState *jobs.RunState
progressLogger *cmdio.Logger
ctx context.Context
prevState *jobs.RunState
}

// onProgress is the single callback that handles all state tracking and logging.
Expand All @@ -104,7 +101,7 @@ func (m *jobRunMonitor) onProgress(info *jobs.Run) {
// First time we see this run.
if m.prevState == nil {
log.Infof(m.ctx, "Run available at %s", info.RunPageUrl)
m.progressLogger.Log(progress.NewJobRunUrlEvent(info.RunPageUrl))
cmdio.Log(m.ctx, progress.NewJobRunUrlEvent(info.RunPageUrl))
}

// No state change: do not log.
Expand All @@ -125,7 +122,7 @@ func (m *jobRunMonitor) onProgress(info *jobs.Run) {
RunName: info.RunName,
State: *info.State,
}
m.progressLogger.Log(event)
cmdio.Log(m.ctx, event)
log.Info(m.ctx, event.String())
}

Expand All @@ -151,15 +148,8 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e

w := r.bundle.WorkspaceClient()

// callback to log progress events. Called on every poll request
progressLogger, ok := cmdio.FromContext(ctx)
if !ok {
return nil, errors.New("no progress logger found")
}

monitor := &jobRunMonitor{
ctx: ctx,
progressLogger: progressLogger,
ctx: ctx,
}

waiter, err := w.Jobs.RunNow(ctx, *req)
Expand All @@ -171,7 +161,7 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e
details, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{
RunId: waiter.RunId,
})
progressLogger.Log(progress.NewJobRunUrlEvent(details.RunPageUrl))
cmdio.Log(ctx, progress.NewJobRunUrlEvent(details.RunPageUrl))
return nil, err
}

Expand Down
3 changes: 0 additions & 3 deletions bundle/run/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -160,7 +159,6 @@ func TestJobRunnerRestart(t *testing.T) {
b.SetWorkpaceClient(m.WorkspaceClient)

ctx := cmdio.MockDiscard(context.Background())
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
Expand Down Expand Up @@ -231,7 +229,6 @@ func TestJobRunnerRestartForContinuousUnpausedJobs(t *testing.T) {
b.SetWorkpaceClient(m.WorkspaceClient)

ctx := cmdio.MockDiscard(context.Background())
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

jobApi := m.GetMockJobsAPI()

Expand Down
8 changes: 2 additions & 6 deletions bundle/run/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,9 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp

// setup progress logger and tracker to query events
updateTracker := progress.NewUpdateTracker(pipelineID, updateID, w)
progressLogger, ok := cmdio.FromContext(ctx)
if !ok {
return nil, errors.New("no progress logger found")
}

// Log the pipeline update URL as soon as it is available.
progressLogger.Log(progress.NewPipelineUpdateUrlEvent(w.Config.Host, updateID, pipelineID))
cmdio.Log(ctx, progress.NewPipelineUpdateUrlEvent(w.Config.Host, updateID, pipelineID))

if opts.NoWait {
return &output.PipelineOutput{
Expand All @@ -129,7 +125,7 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
return nil, err
}
for _, event := range events {
progressLogger.Log(&event)
cmdio.Log(ctx, &event)
log.Info(ctx, event.String())
}

Expand Down
2 changes: 0 additions & 2 deletions bundle/run/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
sdk_config "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines"
Expand Down Expand Up @@ -76,7 +75,6 @@ func TestPipelineRunnerRestart(t *testing.T) {
b.SetWorkpaceClient(m.WorkspaceClient)

ctx := cmdio.MockDiscard(context.Background())
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
Expand Down
3 changes: 2 additions & 1 deletion bundle/statemgmt/state_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
mockfiler "github.com/databricks/cli/internal/mocks/libs/filer"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/filer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -51,7 +52,7 @@ func TestStatePush(t *testing.T) {
identityFiler(mock),
}

ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
b := statePushTestBundle(t)

// Write a stale local state file.
Expand Down
5 changes: 1 addition & 4 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/databricks/cli/internal/build"
"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/log"
"github.com/databricks/cli/libs/telemetry"
Expand Down Expand Up @@ -140,9 +139,7 @@ Stack Trace:
// Run the command
cmd, err = cmd.ExecuteContextC(ctx)
if err != nil && !errors.Is(err, ErrAlreadyPrinted) {
// If cmdio logger initialization succeeds, then this function logs with the
// initialized cmdio logger, otherwise with the default cmdio logger
cmdio.LogError(cmd.Context(), err)
fmt.Fprintf(cmd.ErrOrStderr(), "Error: %s\n", err.Error())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for a dependency on cmdio here.

The previous implementation wrote directly to os.Stderr.

}

// Log exit status and error
Expand Down
5 changes: 3 additions & 2 deletions experimental/ssh/internal/proxy/client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ import (
"testing"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration) *httptest.Server {
ctx := t.Context()
ctx := cmdio.MockDiscard(t.Context())
connections := NewConnectionsManager(maxClients, shutdownDelay)
proxyServer := NewProxyServer(ctx, connections, func(ctx context.Context) *exec.Cmd {
// 'cat' command reads each line from stdin and sends it to stdout, so we can test end-to-end proxying.
Expand All @@ -30,7 +31,7 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
}

func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) (io.WriteCloser, *testBuffer) {
ctx := t.Context()
ctx := cmdio.MockDiscard(t.Context())
clientInput, clientInputWriter := io.Pipe()
clientOutput := newTestBuffer(t)
wsURL := "ws" + serverURL[4:]
Expand Down
13 changes: 7 additions & 6 deletions experimental/ssh/internal/setup/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import (
"testing"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidateClusterAccess_SingleUser(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
m := mocks.NewMockWorkspaceClient(t)
clustersAPI := m.GetMockClustersAPI()

Expand All @@ -29,7 +30,7 @@ func TestValidateClusterAccess_SingleUser(t *testing.T) {
}

func TestValidateClusterAccess_InvalidAccessMode(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
m := mocks.NewMockWorkspaceClient(t)
clustersAPI := m.GetMockClustersAPI()

Expand All @@ -43,7 +44,7 @@ func TestValidateClusterAccess_InvalidAccessMode(t *testing.T) {
}

func TestValidateClusterAccess_ClusterNotFound(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
m := mocks.NewMockWorkspaceClient(t)
clustersAPI := m.GetMockClustersAPI()

Expand Down Expand Up @@ -315,7 +316,7 @@ func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) {
}

func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "ssh_config")

Expand Down Expand Up @@ -349,7 +350,7 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
}

func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "ssh_config")

Expand Down Expand Up @@ -393,7 +394,7 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
}

func TestSetup_DoesNotOverrideExistingHost(t *testing.T) {
ctx := context.Background()
ctx := cmdio.MockDiscard(context.Background())
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "ssh_config")

Expand Down
137 changes: 137 additions & 0 deletions libs/cmdio/compat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package cmdio

import (
"context"
"fmt"
"io"
"strings"

"github.com/manifoldco/promptui"
)

/*
Temporary compatibility layer for the progress logger interfaces.
*/

// Log is a compatibility layer for the progress logger interfaces.
// It writes the string representation of the stringer to the error writer.
func Log(ctx context.Context, str fmt.Stringer) {
LogString(ctx, str.String())
}

// LogString is a compatibility layer for the progress logger interfaces.
// It writes the string to the error writer.
func LogString(ctx context.Context, str string) {
c := fromContext(ctx)
_, _ = io.WriteString(c.err, str)
_, _ = io.WriteString(c.err, "\n")
}

// readLine reads a line from the reader and returns it without the trailing newline characters.
// It is unbuffered because cmdio's stdin is also unbuffered.
// If we were to add a [bufio.Reader] to the mix, we would need to update the other uses of the reader.
// Once cmdio's stdio is made to be buffered, this function can be removed.
func readLine(r io.Reader) (string, error) {
var b strings.Builder
buf := make([]byte, 1)
for {
n, err := r.Read(buf)
if n > 0 {
if buf[0] == '\n' {
break
}
if buf[0] != '\r' {
b.WriteByte(buf[0])
}
}
if err != nil {
if b.Len() == 0 {
return "", err
}
break
}
}
return b.String(), nil
}

// Ask is a compatibility layer for the progress logger interfaces.
// It prompts the user with a question and returns the answer.
func Ask(ctx context.Context, question, defaultVal string) (string, error) {
c := fromContext(ctx)

// Add default value to question prompt.
if defaultVal != "" {
question += fmt.Sprintf(` [%s]`, defaultVal)
}
question += `: `

// Print prompt.
_, err := io.WriteString(c.err, question)
if err != nil {
return "", err
}

// Read user input. Trim new line characters.
ans, err := readLine(c.in)
if err != nil {
return "", err
}

// Return default value if user just presses enter.
if ans == "" {
return defaultVal, nil
}

return ans, nil
}

// AskYesOrNo is a compatibility layer for the progress logger interfaces.
// It prompts the user with a question and returns the answer.
func AskYesOrNo(ctx context.Context, question string) (bool, error) {
ans, err := Ask(ctx, question+" [y/n]", "")
if err != nil {
return false, err
}
return ans == "y", nil
}

func splitAtLastNewLine(s string) (string, string) {
// Split at the newline character
if i := strings.LastIndex(s, "\n"); i != -1 {
return s[:i+1], s[i+1:]
}
// Return the original string if no newline found
return "", s
}

// AskSelect is a compatibility layer for the progress logger interfaces.
// It prompts the user with a question and returns the answer.
func AskSelect(ctx context.Context, question string, choices []string) (string, error) {
c := fromContext(ctx)

// Promptui does not support multiline prompts. So we split the question.
first, last := splitAtLastNewLine(question)
_, err := io.WriteString(c.err, first)
if err != nil {
return "", err
}

// Note: by default this prompt uses os.Stdin and os.Stdout.
// This is contrary to the rest of the original progress logger
// functions that write to stderr.
prompt := promptui.Select{
Label: last,
Items: choices,
HideHelp: true,
Templates: &promptui.SelectTemplates{
Label: "{{.}}: ",
Selected: last + ": {{.}}",
},
}

_, ans, err := prompt.Run()
if err != nil {
return "", err
}
return ans, nil
}
Loading
Loading