Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 35 additions & 104 deletions cli/command/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,120 +106,68 @@ func TestPromptForConfirmation(t *testing.T) {
}()

for _, tc := range []struct {
desc string
f func(*testing.T, context.Context, chan promptResult)
desc string
f func() error
expected promptResult
}{
{"SIGINT", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

{"SIGINT", func() error {
syscall.Kill(syscall.Getpid(), syscall.SIGINT)

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after SIGINT")
case r := <-c:
assert.Check(t, !r.result)
assert.ErrorContains(t, r.err, "prompt terminated")
}
}},
{"no", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

return nil
}, promptResult{false, command.ErrPromptTerminated}},
{"no", func() error {
_, err := fmt.Fprint(promptWriter, "n\n")
assert.NilError(t, err)

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `n`")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
{"yes", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

return err
}, promptResult{false, nil}},
{"yes", func() error {
_, err := fmt.Fprint(promptWriter, "y\n")
assert.NilError(t, err)

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `y`")
case r := <-c:
assert.Check(t, r.result)
assert.NilError(t, r.err)
}
}},
{"any", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

return err
}, promptResult{true, nil}},
{"any", func() error {
_, err := fmt.Fprint(promptWriter, "a\n")
assert.NilError(t, err)

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `a`")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
{"with space", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

return err
}, promptResult{false, nil}},
{"with space", func() error {
_, err := fmt.Fprint(promptWriter, " y\n")
assert.NilError(t, err)

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input ` y`")
case r := <-c:
assert.Check(t, r.result)
assert.NilError(t, r.err)
}
}},
{"reader closed", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()

assert.NilError(t, promptReader.Close())

select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after promptReader was closed")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
return err
}, promptResult{true, nil}},
{"reader closed", func() error {
return promptReader.Close()
}, promptResult{false, nil}},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
buf.Reset()
promptReader, promptWriter = io.Pipe()

wroteHook := make(chan struct{}, 1)
defer close(wroteHook)
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
wroteHook <- struct{}{}
})

result := make(chan promptResult, 1)
defer close(result)
go func() {
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
result <- promptResult{r, err}
}()

// wait for the Prompt to write to the buffer
pollForPromptOutput(ctx, t, wroteHook)
drainChannel(ctx, wroteHook)
select {
case <-time.After(100 * time.Millisecond):
case <-wroteHook:
}

assert.NilError(t, bufioWriter.Flush())
assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]")

resultCtx, resultCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer resultCancel()
// wait for the Prompt to write to the buffer
drainChannel(ctx, wroteHook)

assert.NilError(t, tc.f())

tc.f(t, resultCtx, result)
select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout waiting for prompt result")
case r := <-result:
assert.Equal(t, r, tc.expected)
}
})
}
}
Expand All @@ -235,20 +183,3 @@ func drainChannel(ctx context.Context, ch <-chan struct{}) {
}
}()
}

func pollForPromptOutput(ctx context.Context, t *testing.T, wroteHook <-chan struct{}) {
t.Helper()

ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()

for {
select {
case <-ctx.Done():
t.Fatal("Prompt output was not written to before ctx was cancelled")
return
case <-wroteHook:
return
}
}
}