Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr
return fmt.Errorf("could not create expect options: %w", err)
}

cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...)
cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...)
if err != nil {
return fmt.Errorf("could not add consumer: %w", err)
}
Expand Down Expand Up @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp
select {
case <-time.After(timeoutV):
return fmt.Errorf("after %s: %w", timeoutV, TimeoutError)
case err := <-waitChan(tt.cmd.Wait):
if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", err)
case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{}
if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", state.Err)
}
if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil {
if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil {
return err
}
}
Expand Down
8 changes: 4 additions & 4 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ type cmdExit struct {
}

// waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement
func waitForCmdExit(cmd *exec.Cmd) chan cmdExit {
exit := make(chan cmdExit, 1)
func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit {
exit := make(chan *cmdExit, 1)
go func() {
err := cmd.Wait()
exit <- cmdExit{ProcessState: cmd.ProcessState, Err: err}
exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err}
}()
return exit
}

func waitChan[T any](wait func() T) chan T {
done := make(chan T)
go func() {
wait()
done <- wait()
close(done)
}()
return done
Expand Down
10 changes: 9 additions & 1 deletion outputconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type outputConsumer struct {
opts *OutputConsumerOpts
isalive bool
mutex *sync.Mutex
tt *TermTest
}

type OutputConsumerOpts struct {
Expand All @@ -36,7 +37,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) {
}
}

func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer {
oc := &outputConsumer{
consume: consume,
opts: &OutputConsumerOpts{
Expand All @@ -46,6 +47,7 @@ func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
waiter: make(chan error, 1),
isalive: true,
mutex: &sync.Mutex{},
tt: tt,
}

for _, optSetter := range opts {
Expand Down Expand Up @@ -101,5 +103,11 @@ func (e *outputConsumer) wait() error {
e.mutex.Lock()
e.opts.Logger.Println("Encountered timeout")
return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError)
case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{}
e.mutex.Lock()
if state.Err != nil {
e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error())
}
return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode())
}
}
4 changes: 2 additions & 2 deletions outputproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,12 @@ func (o *outputProducer) flushConsumers() error {
return nil
}

func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
o.opts.Logger.Printf("adding consumer")
defer o.opts.Logger.Printf("added consumer")

opts = append(opts, OptConsInherit(o.opts))
listener := newOutputConsumer(consume, opts...)
listener := newOutputConsumer(tt, consume, opts...)
o.consumers = append(o.consumers, listener)

if err := o.flushConsumers(); err != nil {
Expand Down
30 changes: 30 additions & 0 deletions termtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type TermTest struct {
outputProducer *outputProducer
listenError chan error
opts *Opts
exited *cmdExit
}

type ErrorHandler func(*TermTest, error) error
Expand All @@ -50,6 +51,9 @@ type SetOpt func(o *Opts) error
const DefaultCols = 140
const DefaultRows = 10

var processExitPollInterval = 10 * time.Millisecond
var processExitExtraWait = 500 * time.Millisecond

func NewOpts() *Opts {
return &Opts{
Logger: VoidLogger,
Expand Down Expand Up @@ -234,6 +238,10 @@ func (tt *TermTest) start() (rerr error) {
}()
wg.Wait()

go func() {
tt.exited = <-waitForCmdExit(tt.cmd)
}()

return nil
}

Expand Down Expand Up @@ -316,6 +324,28 @@ func (tt *TermTest) SendCtrlC() {
tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C
}

// Exited returns a channel that sends the given termtest's command cmdExit info when available.
// This can be used within a select{} statement.
// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow
// switch cases with output consumers to handle unprocessed stdout. If there are no such cases
// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false.
func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit {
return waitChan(func() *cmdExit {
ticker := time.NewTicker(processExitPollInterval)
for {
select {
case <-ticker.C:
if tt.exited != nil {
if waitExtra { // allow sibling output consumer cases to handle their output
time.Sleep(processExitExtraWait)
}
return tt.exited
}
}
}
})
}

func (tt *TermTest) errorHandler(rerr *error) {
err := *rerr
if err == nil {
Expand Down