diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf20022..d954746 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: unit-tests on: push: - branches: [ master ] + branches: [ master, v2-wip ] pull_request: - branches: [ master ] + branches: [ master, v2-wip] jobs: diff --git a/expect.go b/expect.go index ee26707..d340777 100644 --- a/expect.go +++ b/expect.go @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr return fmt.Errorf("could not create expect options: %w", err) } - cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...) + cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...) if err != nil { return fmt.Errorf("could not add consumer: %w", err) } @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp select { case <-time.After(timeoutV): return fmt.Errorf("after %s: %w", timeoutV, TimeoutError) - case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{} - if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) { - return fmt.Errorf("cmd wait failed: %w", state.Err) + case err := <-waitChan(tt.cmd.Wait): + if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) { + return fmt.Errorf("cmd wait failed: %w", err) } - if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil { + if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil { return err } } diff --git a/expect_test.go b/expect_test.go index 54206dd..54ddb7c 100644 --- a/expect_test.go +++ b/expect_test.go @@ -85,7 +85,7 @@ func Test_ExpectCustom(t *testing.T) { []SetExpectOpt{OptExpectTimeout(time.Second)}, }, "", - TimeoutError, + PtyEOF, }, { "Custom error", @@ -167,7 +167,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) { }, []SetExpectOpt{OptExpectTimeout(time.Second)}, }, - TimeoutError, + PtyEOF, }, { "Custom error", @@ -194,7 +194,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) { } func Test_Expect_Timeout(t *testing.T) { - tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO"), false) + tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && sleep 1"), false) durations := []time.Duration{ 100 * time.Millisecond, 200 * time.Millisecond, diff --git a/helpers.go b/helpers.go index 9092a75..ef5b241 100644 --- a/helpers.go +++ b/helpers.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "os" - "os/exec" "strings" "time" ) @@ -23,20 +22,10 @@ type cmdExit struct { Err error } -// waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement -func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit { - exit := make(chan *cmdExit, 1) - go func() { - err := cmd.Wait() - exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err} - }() - return exit -} - func waitChan[T any](wait func() T) chan T { - done := make(chan T) + done := make(chan T, 1) go func() { - done <- wait() + wait() close(done) }() return done diff --git a/outputconsumer.go b/outputconsumer.go index 5e51c26..b5dc925 100644 --- a/outputconsumer.go +++ b/outputconsumer.go @@ -15,7 +15,6 @@ type outputConsumer struct { opts *OutputConsumerOpts isalive bool mutex *sync.Mutex - tt *TermTest } type OutputConsumerOpts struct { @@ -37,7 +36,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) { } } -func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer { +func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer { oc := &outputConsumer{ consume: consume, opts: &OutputConsumerOpts{ @@ -47,7 +46,6 @@ func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outp waiter: make(chan error, 1), isalive: true, mutex: &sync.Mutex{}, - tt: tt, } for _, optSetter := range opts { @@ -83,6 +81,23 @@ func (e *outputConsumer) Report(buffer []byte) (int, error) { return pos, err } +type errConsumerStopped struct { + reason error +} + +func (e errConsumerStopped) Error() string { + return fmt.Sprintf("consumer stopped, reason: %s", e.reason) +} + +func (e errConsumerStopped) Unwrap() error { + return e.reason +} + +func (e *outputConsumer) Stop(reason error) { + e.opts.Logger.Printf("stopping consumer, reason: %s\n", reason) + e.waiter <- errConsumerStopped{reason} +} + func (e *outputConsumer) wait() error { e.opts.Logger.Println("started waiting") defer e.opts.Logger.Println("stopped waiting") @@ -103,11 +118,5 @@ func (e *outputConsumer) wait() error { e.mutex.Lock() e.opts.Logger.Println("Encountered timeout") return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError) - case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{} - e.mutex.Lock() - if state.Err != nil { - e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error()) - } - return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode()) } } diff --git a/outputproducer.go b/outputproducer.go index 8c68b7a..178a06b 100644 --- a/outputproducer.go +++ b/outputproducer.go @@ -54,8 +54,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by for { o.opts.Logger.Println("listen: loop") if err := o.processNextRead(br, w, appendBuffer, size); err != nil { - if errors.Is(err, ptyEOF) { - o.opts.Logger.Println("listen: reached EOF") + if errors.Is(err, PtyEOF) { return nil } else { return fmt.Errorf("could not poll reader: %w", err) @@ -64,7 +63,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by } } -var ptyEOF = errors.New("pty closed") +var PtyEOF = errors.New("pty closed") func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer func([]byte, bool) error, size int) error { o.opts.Logger.Printf("processNextRead started with size: %d\n", size) @@ -78,6 +77,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer pathError := &fs.PathError{} if errors.Is(errRead, fs.ErrClosed) || errors.Is(errRead, io.EOF) || (runtime.GOOS == "linux" && errors.As(errRead, &pathError)) { isEOF = true + o.opts.Logger.Println("reached EOF") } } @@ -96,7 +96,8 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer if errRead != nil { if isEOF { - return errors.Join(errRead, ptyEOF) + o.closeConsumers(PtyEOF) + return errors.Join(errRead, PtyEOF) } return fmt.Errorf("could not read pty output: %w", errRead) } @@ -194,6 +195,19 @@ func (o *outputProducer) processDirtyOutput(output []byte, cursorPos int, cleanU return append(append(alreadyCleanedOutput, processedOutput...), unprocessedOutput...), processedCursorPos, newCleanUptoPos, nil } +func (o *outputProducer) closeConsumers(reason error) { + o.opts.Logger.Println("closing consumers") + defer o.opts.Logger.Println("closed consumers") + + o.mutex.Lock() + defer o.mutex.Unlock() + + for n := 0; n < len(o.consumers); n++ { + o.consumers[n].Stop(reason) + o.consumers = append(o.consumers[:n], o.consumers[n+1:]...) + } +} + func (o *outputProducer) flushConsumers() error { o.opts.Logger.Println("flushing consumers") defer o.opts.Logger.Println("flushed consumers") @@ -238,12 +252,12 @@ func (o *outputProducer) flushConsumers() error { return nil } -func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) { +func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) { o.opts.Logger.Printf("adding consumer") defer o.opts.Logger.Printf("added consumer") opts = append(opts, OptConsInherit(o.opts)) - listener := newOutputConsumer(tt, consume, opts...) + listener := newOutputConsumer(consume, opts...) o.consumers = append(o.consumers, listener) if err := o.flushConsumers(); err != nil { diff --git a/termtest.go b/termtest.go index d363ad0..fbcc2aa 100644 --- a/termtest.go +++ b/termtest.go @@ -23,8 +23,8 @@ type TermTest struct { ptmx pty.Pty outputProducer *outputProducer listenError chan error + waitError chan error opts *Opts - exited *cmdExit } type ErrorHandler func(*TermTest, error) error @@ -79,6 +79,7 @@ func New(cmd *exec.Cmd, opts ...SetOpt) (*TermTest, error) { cmd: cmd, outputProducer: newOutputProducer(optv), listenError: make(chan error, 1), + waitError: make(chan error, 1), opts: optv, } @@ -228,6 +229,7 @@ func (tt *TermTest) start() (rerr error) { tt.term = vt10x.New(vt10x.WithWriter(ptmx), vt10x.WithSize(tt.opts.Cols, tt.opts.Rows)) // Start listening for output + // We use a waitgroup here to ensure the listener is active before consumers are attached. wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -236,12 +238,18 @@ func (tt *TermTest) start() (rerr error) { err := tt.outputProducer.Listen(tt.ptmx, tt.term) tt.listenError <- err }() - wg.Wait() go func() { - tt.exited = <-waitForCmdExit(tt.cmd) + // We start waiting right away, because on Windows the PTY isn't closed until the process exits, which in turn + // can't happen unless we've told the pty we're ready for it to close. + // This of course isn't ideal, but until the pty library fixes the cross-platform inconsistencies we have to + // work around these limitations. + defer tt.opts.Logger.Printf("waitIndefinitely finished") + tt.waitError <- tt.waitIndefinitely() }() + wg.Wait() + return nil } @@ -252,13 +260,8 @@ func (tt *TermTest) Wait(timeout time.Duration) (rerr error) { tt.opts.Logger.Println("wait called") defer tt.opts.Logger.Println("wait closed") - errc := make(chan error, 1) - go func() { - errc <- tt.WaitIndefinitely() - }() - select { - case err := <-errc: + case err := <-tt.waitError: // WaitIndefinitely already invokes the expect error handler return err case <-time.After(timeout): @@ -324,28 +327,6 @@ func (tt *TermTest) SendCtrlC() { tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C } -// Exited returns a channel that sends the given termtest's command cmdExit info when available. -// This can be used within a select{} statement. -// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow -// switch cases with output consumers to handle unprocessed stdout. If there are no such cases -// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false. -func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit { - return waitChan(func() *cmdExit { - ticker := time.NewTicker(processExitPollInterval) - for { - select { - case <-ticker.C: - if tt.exited != nil { - if waitExtra { // allow sibling output consumer cases to handle their output - time.Sleep(processExitExtraWait) - } - return tt.exited - } - } - } - }) -} - func (tt *TermTest) errorHandler(rerr *error) { err := *rerr if err == nil { diff --git a/termtest_other.go b/termtest_other.go index 79cbc1b..21db974 100644 --- a/termtest_other.go +++ b/termtest_other.go @@ -12,7 +12,7 @@ func syscallErrorCode(err error) int { return -1 } -func (tt *TermTest) WaitIndefinitely() error { +func (tt *TermTest) waitIndefinitely() error { tt.opts.Logger.Println("WaitIndefinitely called") defer tt.opts.Logger.Println("WaitIndefinitely closed") diff --git a/termtest_windows.go b/termtest_windows.go index b767028..c01241f 100644 --- a/termtest_windows.go +++ b/termtest_windows.go @@ -16,10 +16,10 @@ func syscallErrorCode(err error) int { return 0 } -// WaitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself: +// waitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself: // https://github.com/photostorm/pty/issues/3 // Instead we wait for the process itself to exit, and after a grace period will shut down the pty. -func (tt *TermTest) WaitIndefinitely() error { +func (tt *TermTest) waitIndefinitely() error { tt.opts.Logger.Println("WaitIndefinitely called") defer tt.opts.Logger.Println("WaitIndefinitely closed")