diff --git a/cmd/server/server.go b/cmd/server/server.go index 6a7fa7f0..6d5cdec3 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -136,7 +136,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } - srv.StartSnapshotLoop(ctx) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { diff --git a/e2e/echo.go b/e2e/echo.go index 7388d5e1..09235393 100644 --- a/e2e/echo.go +++ b/e2e/echo.go @@ -89,11 +89,10 @@ func runEchoAgent(scriptPath string) { if entry.ThinkDurationMS > 0 { redrawTerminal(messages, true) spinnerCtx, spinnerCancel := context.WithCancel(ctx) - go runSpinner(spinnerCtx) + spinnerDone := runSpinner(spinnerCtx) time.Sleep(time.Duration(entry.ThinkDurationMS) * time.Millisecond) - if spinnerCancel != nil { - spinnerCancel() - } + spinnerCancel() + <-spinnerDone } messages = append(messages, st.ConversationMessage{ @@ -133,9 +132,10 @@ func runEchoAgent(scriptPath string) { if entry.ThinkDurationMS > 0 { redrawTerminal(messages, true) spinnerCtx, spinnerCancel := context.WithCancel(ctx) - go runSpinner(spinnerCtx) + spinnerDone := runSpinner(spinnerCtx) time.Sleep(time.Duration(entry.ThinkDurationMS) * time.Millisecond) spinnerCancel() + <-spinnerDone } messages = append(messages, st.ConversationMessage{ @@ -190,21 +190,26 @@ func cleanTerminalInput(input string) string { return strings.TrimSpace(input) } -func runSpinner(ctx context.Context) { - spinnerChars := []string{"|", "/", "-", "\\"} - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() - i := 0 - - for { - select { - case <-ticker.C: - fmt.Printf("\rThinking %s", spinnerChars[i%len(spinnerChars)]) - i++ - case <-ctx.Done(): - // Clear spinner on cancellation - fmt.Print("\r" + strings.Repeat(" ", 20) + "\r") - return +func runSpinner(ctx context.Context) <-chan struct{} { + done := make(chan struct{}) + go func() { + defer close(done) + spinnerChars := []string{"|", "/", "-", "\\"} + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + i := 0 + + for { + select { + case <-ticker.C: + fmt.Printf("\rThinking %s", spinnerChars[i%len(spinnerChars)]) + i++ + case <-ctx.Done(): + // Clear spinner on cancellation + fmt.Print("\r" + strings.Repeat(" ", 20) + "\r") + return + } } - } + }() + return done } diff --git a/e2e/echo_test.go b/e2e/echo_test.go index fbc1efc0..765521cf 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -133,14 +133,11 @@ func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agenta cwd, err := os.Getwd() require.NoError(t, err, "Failed to get current working directory") binaryPath = filepath.Join(cwd, "..", "out", "agentapi") - _, err = os.Stat(binaryPath) - if err != nil { - t.Logf("Building binary at %s", binaryPath) - buildCmd := exec.CommandContext(ctx, "go", "build", "-o", binaryPath, ".") - buildCmd.Dir = filepath.Join(cwd, "..") - t.Logf("run: %s", buildCmd.String()) - require.NoError(t, buildCmd.Run(), "Failed to build binary") - } + t.Logf("Building binary at %s", binaryPath) + buildCmd := exec.CommandContext(ctx, "go", "build", "-o", binaryPath, ".") + buildCmd.Dir = filepath.Join(cwd, "..") + t.Logf("run: %s", buildCmd.String()) + require.NoError(t, buildCmd.Run(), "Failed to build binary") } serverPort, err := getFreePort() @@ -254,7 +251,11 @@ func waitAgentAPIStable(ctx context.Context, t testing.TB, apiClient *agentapisd return nil } } else { - t.Logf("Got %T event", evt) + var sb strings.Builder + if err := json.NewEncoder(&sb).Encode(evt); err != nil { + t.Logf("Failed to encode event: %v", err) + } + t.Logf("Got event: %s", sb.String()) } case err := <-errs: return fmt.Errorf("read events: %w", err) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index cb13a29e..e43315bf 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -41,7 +41,7 @@ type Server struct { srv *http.Server mu sync.RWMutex logger *slog.Logger - conversation *st.PTYConversation + conversation st.Conversation agentio *termexec.Process agentType mf.AgentType emitter *EventEmitter @@ -244,6 +244,14 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } + emitter := NewEventEmitter(1024) + + // Format initial prompt into message parts if provided + var initialPrompt []st.MessagePart + if config.InitialPrompt != "" { + initialPrompt = FormatMessage(config.AgentType, config.InitialPrompt) + } + conversation := st.NewPTY(ctx, st.PTYConversationConfig{ AgentType: config.AgentType, AgentIO: config.Process, @@ -253,9 +261,17 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { FormatMessage: formatMessage, ReadyForInitialPrompt: isAgentReadyForInitialPrompt, FormatToolCall: formatToolCall, - Logger: logger, - }, config.InitialPrompt) - emitter := NewEventEmitter(1024) + InitialPrompt: initialPrompt, + // OnSnapshot uses a callback rather than passing the emitter directly + // to keep the screentracker package decoupled from httpapi concerns. + // This preserves clean package boundaries and avoids import cycles. + OnSnapshot: func(status st.ConversationStatus, messages []st.ConversationMessage, screen string) { + emitter.UpdateStatusAndEmitChanges(status, config.AgentType) + emitter.UpdateMessagesAndEmitChanges(messages) + emitter.UpdateScreenAndEmitChanges(screen) + }, + Logger: logger, + }) // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") @@ -281,6 +297,16 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { // Register API routes s.registerRoutes() + // Start the conversation polling loop if we have a process. + // Process is nil only when --print-openapi is used (no agent runs). + // The process is already running at this point - termexec.StartProcess() + // blocks until the PTY is created and the process is active. Agent + // readiness (waiting for the prompt) is handled asynchronously inside + // conversation.Start() via ReadyForInitialPrompt. + if config.Process != nil { + s.conversation.Start(ctx) + } + return s, nil } @@ -336,38 +362,6 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { next(ctx) } -func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.Start(ctx) - go func() { - ticker := s.clock.NewTicker(snapshotInterval) - defer ticker.Stop() - for { - currentStatus := s.conversation.Status() - - // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") - } - } - s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) - s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Text()) - - select { - case <-ctx.Done(): - return - case <-ticker.C: - } - } - }() -} - // registerRoutes sets up all API endpoints func (s *Server) registerRoutes() { // GET /status endpoint diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 38d8e409..551ea50d 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -43,6 +43,12 @@ func (p MessagePartText) String() string { return p.Content } +// outboundMessage wraps a message to be sent with its error channel +type outboundMessage struct { + parts []MessagePart + errCh chan error +} + // PTYConversationConfig is the configuration for a PTYConversation. type PTYConversationConfig struct { AgentType msgfmt.AgentType @@ -56,17 +62,15 @@ type PTYConversationConfig struct { // Function to format the messages received from the agent // userInput is the last user message FormatMessage func(message string, userInput string) string - // SkipWritingMessage skips the writing of a message to the agent. - // This is used in tests - SkipWritingMessage bool - // SkipSendMessageStatusCheck skips the check for whether the message can be sent. - // This is used in tests - SkipSendMessageStatusCheck bool // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt ReadyForInitialPrompt func(message string) bool // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) - Logger *slog.Logger + // InitialPrompt is the initial prompt to send to the agent once ready + InitialPrompt []MessagePart + // OnSnapshot is called after each snapshot with current status, messages, and screen content + OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) + Logger *slog.Logger } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -90,19 +94,25 @@ type PTYConversation struct { screenBeforeLastUserMessage string lock sync.Mutex - // InitialPrompt is the initial prompt passed to the agent - InitialPrompt string - // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents - InitialPromptSent bool - // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt - ReadyForInitialPrompt bool + // outboundQueue holds messages waiting to be sent to the agent + outboundQueue chan outboundMessage + // stableSignal is used by the snapshot loop to signal the send loop + // when the agent is stable and there are items in the outbound queue. + stableSignal chan struct{} // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // initialPromptReady is closed when ReadyForInitialPrompt returns true. + // This is checked by a separate goroutine to avoid calling ReadyForInitialPrompt on every tick. + initialPromptReady chan struct{} } var _ Conversation = &PTYConversation{} -func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string) *PTYConversation { +// errInitialPromptReady is a sentinel used to stop the readiness TickerFunc +// after ReadyForInitialPrompt returns true. +var errInitialPromptReady = xerrors.New("initial prompt ready") + +func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { if cfg.Clock == nil { cfg.Clock = quartz.NewReal() } @@ -118,33 +128,85 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string Time: cfg.Clock.Now(), }, }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), toolCallMessageSet: make(map[string]bool), + initialPromptReady: make(chan struct{}), + } + // If we have an initial prompt, enqueue it + if len(cfg.InitialPrompt) > 0 { + c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} + } + if c.cfg.OnSnapshot == nil { + c.cfg.OnSnapshot = func(ConversationStatus, []ConversationMessage, string) {} + } + if c.cfg.ReadyForInitialPrompt == nil { + c.cfg.ReadyForInitialPrompt = func(string) bool { return true } } return c } func (c *PTYConversation) Start(ctx context.Context) { + // Initial prompt readiness loop - polls ReadyForInitialPrompt until it returns true, + // then closes initialPromptReady and exits. This avoids calling ReadyForInitialPrompt + // on every snapshot tick. + c.cfg.Clock.TickerFunc(ctx, 100*time.Millisecond, func() error { + screen := c.cfg.AgentIO.ReadScreen() + if c.cfg.ReadyForInitialPrompt(screen) { + close(c.initialPromptReady) + return errInitialPromptReady + } + return nil + }, "readiness") + + // Snapshot loop + c.cfg.Clock.TickerFunc(ctx, c.cfg.SnapshotInterval, func() error { + c.lock.Lock() + screen := c.cfg.AgentIO.ReadScreen() + c.snapshotLocked(screen) + status := c.statusLocked() + messages := c.messagesLocked() + + // Signal send loop if agent is ready and queue has items. + // We check readiness independently of statusLocked() because + // statusLocked() returns "changing" when queue has items. + isReady := false + select { + case <-c.initialPromptReady: + isReady = true + default: + } + if isReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { + select { + case c.stableSignal <- struct{}{}: + default: + // Signal already pending + } + } + c.lock.Unlock() + + c.cfg.OnSnapshot(status, messages, screen) + return nil + }, "snapshot") + + // Send loop - primary call site for sendLocked() in production go func() { - ticker := c.cfg.Clock.NewTicker(c.cfg.SnapshotInterval) - defer ticker.Stop() for { select { case <-ctx.Done(): return - case <-ticker.C: - // It's important that we hold the lock while reading the screen. - // There's a race condition that occurs without it: - // 1. The screen is read - // 2. Independently, Send is called and takes the lock. - // 3. snapshotLocked is called and waits on the lock. - // 4. Send modifies the terminal state, releases the lock - // 5. snapshotLocked adds a snapshot from a stale screen - c.lock.Lock() - screen := c.cfg.AgentIO.ReadScreen() - c.snapshotLocked(screen) - c.lock.Unlock() + case <-c.stableSignal: + select { + case <-ctx.Done(): + return + case msg := <-c.outboundQueue: + err := c.sendMessage(ctx, msg.parts...) + if msg.errCh != nil { + msg.errCh <- err + } + default: + c.cfg.Logger.Error("received stable signal but outbound queue is empty") + } } } }() @@ -198,16 +260,6 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1].Id = len(c.messages) - 1 } -// Snapshot writes the current screen snapshot to the snapshot buffer. -// ONLY TO BE USED FOR TESTING PURPOSES. -// TODO(Cian): This method can be removed by mocking AgentIO. -func (c *PTYConversation) Snapshot(screen string) { - c.lock.Lock() - defer c.lock.Unlock() - - c.snapshotLocked(screen) -} - // caller MUST hold c.lock func (c *PTYConversation) snapshotLocked(screen string) { snapshot := screenSnapshot{ @@ -219,35 +271,62 @@ func (c *PTYConversation) snapshotLocked(screen string) { } func (c *PTYConversation) Send(messageParts ...MessagePart) error { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { - return ErrMessageValidationChanging - } - + // Validate message content before enqueueing var sb strings.Builder for _, part := range messageParts { sb.WriteString(part.String()) } message := sb.String() if message != msgfmt.TrimWhitespace(message) { - // msgfmt formatting functions assume this return ErrMessageValidationWhitespace } if message == "" { - // writeMessageWithConfirmation requires a non-empty message return ErrMessageValidationEmpty } + c.lock.Lock() + if c.statusLocked() != ConversationStatusStable { + c.lock.Unlock() + return ErrMessageValidationChanging + } + c.lock.Unlock() + + errCh := make(chan error, 1) + c.outboundQueue <- outboundMessage{parts: messageParts, errCh: errCh} + return <-errCh +} + +// sendMessage sends a message to the agent. It acquires and releases c.lock +// around the parts that access shared state, but releases it during +// writeStabilize to avoid blocking the snapshot loop. +func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...MessagePart) error { + var sb strings.Builder + for _, part := range messageParts { + sb.WriteString(part.String()) + } + message := sb.String() + + c.lock.Lock() screenBeforeMessage := c.cfg.AgentIO.ReadScreen() now := c.cfg.Clock.Now() c.updateLastAgentMessageLocked(screenBeforeMessage, now) + c.lock.Unlock() - if err := c.writeStabilize(context.Background(), messageParts...); err != nil { + if err := c.writeStabilize(ctx, messageParts...); err != nil { return xerrors.Errorf("failed to send message: %w", err) } + c.lock.Lock() + // Re-apply the pre-send agent message from the screen captured before + // the write. While the lock was released during writeStabilize, the + // snapshot loop continued taking snapshots and calling + // updateLastAgentMessageLocked with whatever was on screen at each + // tick (typically echoed user input or intermediate terminal state). + // Those updates corrupt the agent message for this turn. Restoring it + // here ensures the conversation history is correct. The next line sets + // screenBeforeLastUserMessage so the *next* agent message will be + // diffed relative to the pre-send screen. + c.updateLastAgentMessageLocked(screenBeforeMessage, now) c.screenBeforeLastUserMessage = screenBeforeMessage c.messages = append(c.messages, ConversationMessage{ Id: len(c.messages), @@ -255,14 +334,12 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { Role: ConversationRoleUser, Time: now, }) + c.lock.Unlock() return nil } // writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written. func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...MessagePart) error { - if c.cfg.SkipWritingMessage { - return nil - } screenBeforeMessage := c.cfg.AgentIO.ReadScreen() for _, part := range messageParts { if err := part.Do(c.cfg.AgentIO); err != nil { @@ -274,6 +351,7 @@ func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...Me Timeout: 15 * time.Second, MinInterval: 50 * time.Millisecond, InitialWait: true, + Clock: c.cfg.Clock, }, func() (bool, error) { screen := c.cfg.AgentIO.ReadScreen() if screen != screenBeforeMessage { @@ -296,6 +374,7 @@ func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...Me if err := util.WaitFor(ctx, util.WaitTimeout{ Timeout: 15 * time.Second, MinInterval: 25 * time.Millisecond, + Clock: c.cfg.Clock, }, func() (bool, error) { // we don't want to spam additional carriage returns because the agent may process them // (aider does this), but we do want to retry sending one if nothing's @@ -328,6 +407,21 @@ func (c *PTYConversation) Status() ConversationStatus { return c.statusLocked() } +// isScreenStableLocked returns true if the screen content has been stable +// for the required number of snapshots. Caller MUST hold c.lock. +func (c *PTYConversation) isScreenStableLocked() bool { + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) < c.stableSnapshotsThreshold { + return false + } + for i := 1; i < len(snapshots); i++ { + if snapshots[0].screen != snapshots[i].screen { + return false + } + } + return true +} + // caller MUST hold c.lock func (c *PTYConversation) statusLocked() ConversationStatus { // sanity checks @@ -350,17 +444,13 @@ func (c *PTYConversation) statusLocked() ConversationStatus { return ConversationStatusInitializing } - for i := 1; i < len(snapshots); i++ { - if snapshots[0].screen != snapshots[i].screen { - return ConversationStatusChanging - } + if !c.isScreenStableLocked() { + return ConversationStatusChanging } - if !c.InitialPromptSent && !c.ReadyForInitialPrompt { - if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { - c.ReadyForInitialPrompt = true - return ConversationStatusStable - } + // Handle initial prompt readiness: report "changing" until the queue is drained + // to avoid the status flipping "changing" -> "stable" -> "changing" + if len(c.outboundQueue) > 0 { return ConversationStatusChanging } @@ -371,6 +461,11 @@ func (c *PTYConversation) Messages() []ConversationMessage { c.lock.Lock() defer c.lock.Unlock() + return c.messagesLocked() +} + +// messagesLocked returns a copy of messages. Caller MUST hold c.lock. +func (c *PTYConversation) messagesLocked() []ConversationMessage { result := make([]ConversationMessage, len(c.messages)) copy(result, c.messages) return result diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index e2903227..eaa4a69e 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -3,48 +3,136 @@ package screentracker_test import ( "context" "fmt" + "io" + "log/slog" + "sync" "testing" "time" "github.com/coder/quartz" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" st "github.com/coder/agentapi/lib/screentracker" ) -type statusTestStep struct { - snapshot string - status st.ConversationStatus -} -type statusTestParams struct { - cfg st.PTYConversationConfig - steps []statusTestStep -} +const testTimeout = 10 * time.Second +// testAgent is a goroutine-safe mock implementation of AgentIO. type testAgent struct { - st.AgentIO - screen string + mu sync.Mutex + screen string + // onWrite is called during Write to simulate the agent reacting to + // terminal input (e.g., changing the screen), which unblocks + // writeStabilize's polling loops. + onWrite func(data []byte) } func (a *testAgent) ReadScreen() string { + a.mu.Lock() + defer a.mu.Unlock() return a.screen } func (a *testAgent) Write(data []byte) (int, error) { - return 0, nil + a.mu.Lock() + defer a.mu.Unlock() + if a.onWrite != nil { + a.onWrite(data) + } + return len(data), nil +} + +func (a *testAgent) setScreen(s string) { + a.mu.Lock() + defer a.mu.Unlock() + a.screen = s +} + +// advanceFor is a shorthand for advanceUntil with a time-based condition. +func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { + t.Helper() + target := mClock.Now().Add(total) + advanceUntil(ctx, t, mClock, func() bool { return !mClock.Now().Before(target) }) +} + +// advanceUntil advances the mock clock one event at a time until done returns +// true. Because the snapshot TickerFunc is always pending and WaitFor reuses a +// single timer via Reset, there is always at least one event to advance. +func advanceUntil(ctx context.Context, t *testing.T, mClock *quartz.Mock, done func() bool) { + t.Helper() + for !done() { + select { + case <-ctx.Done(): + t.Fatal("context cancelled waiting for condition") + default: + } + _, w := mClock.AdvanceNext() + w.MustWait(ctx) + } +} + +// sendAndAdvance calls Send() in a goroutine and advances the mock clock until +// Send completes. +func sendAndAdvance(ctx context.Context, t *testing.T, c *st.PTYConversation, mClock *quartz.Mock, parts ...st.MessagePart) { + t.Helper() + errCh := make(chan error, 1) + go func() { + errCh <- c.Send(parts...) + }() + advanceUntil(ctx, t, mClock, func() bool { + select { + case err := <-errCh: + require.NoError(t, err) + return true + default: + return false + } + }) +} + +func assertMessages(t *testing.T, c *st.PTYConversation, expected []st.ConversationMessage) { + t.Helper() + actual := c.Messages() + for i := range actual { + require.False(t, actual[i].Time.IsZero(), "message %d Time should be non-zero", i) + actual[i].Time = time.Time{} + } + require.Equal(t, expected, actual) +} + +type statusTestStep struct { + snapshot string + status st.ConversationStatus +} +type statusTestParams struct { + cfg st.PTYConversationConfig + steps []statusTestStep } func statusTest(t *testing.T, params statusTestParams) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) t.Run(fmt.Sprintf("interval-%s,stability_length-%s", params.cfg.SnapshotInterval, params.cfg.ScreenStabilityLength), func(t *testing.T) { - if params.cfg.Clock == nil { - params.cfg.Clock = quartz.NewReal() + mClock := quartz.NewMock(t) + params.cfg.Clock = mClock + agent := &testAgent{} + if params.cfg.AgentIO != nil { + if a, ok := params.cfg.AgentIO.(*testAgent); ok { + agent = a + } } - c := st.NewPTY(ctx, params.cfg, "") + params.cfg.AgentIO = agent + params.cfg.Logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + + c := st.NewPTY(ctx, params.cfg) + c.Start(ctx) + assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { - c.Snapshot(step.snapshot) + agent.setScreen(step.snapshot) + advanceFor(ctx, t, mClock, params.cfg.SnapshotInterval) assert.Equal(t, step.status, c.Status(), "step %d", i) } }) @@ -114,58 +202,63 @@ func TestConversation(t *testing.T) { } func TestMessages(t *testing.T) { - now := time.Now() - agentMsg := func(id int, msg string) st.ConversationMessage { - return st.ConversationMessage{ - Id: id, - Message: msg, - Role: st.ConversationRoleAgent, - Time: now, - } - } - userMsg := func(id int, msg string) st.ConversationMessage { - return st.ConversationMessage{ - Id: id, - Message: msg, - Role: st.ConversationRoleUser, - Time: now, + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + + // newConversation creates a started conversation with a mock clock and + // testAgent. Tests that Send() messages must use sendAndAdvance. + newConversation := func(ctx context.Context, t *testing.T, opts ...func(*st.PTYConversationConfig)) (*st.PTYConversation, *testAgent, *quartz.Mock) { + t.Helper() + + writeCounter := 0 + agent := &testAgent{} + // Default onWrite: each write produces a unique screen so that + // writeStabilize can detect screen changes. + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) } - } - sendMsg := func(c *st.PTYConversation, msg string) error { - return c.Send(st.MessagePartText{Content: msg}) - } - newConversation := func(opts ...func(*st.PTYConversationConfig)) *st.PTYConversation { mClock := quartz.NewMock(t) mClock.Set(now) cfg := st.PTYConversationConfig{ Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 2 * time.Second, - SkipWritingMessage: true, - SkipSendMessageStatusCheck: true, + AgentIO: agent, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } for _, opt := range opts { opt(&cfg) } - return st.NewPTY(context.Background(), cfg, "") + if a, ok := cfg.AgentIO.(*testAgent); ok { + agent = a + } + + c := st.NewPTY(ctx, cfg) + c.Start(ctx) + + return c, agent, mClock } + // threshold = 3 (200ms / 100ms = 2, + 1 = 3) + const threshold = 3 + const interval = 100 * time.Millisecond + t.Run("messages are copied", func(t *testing.T) { - c := newConversation() + c, _, _ := newConversation(context.Background(), t) messages := c.Messages() - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, ""), - }, messages) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "", Role: st.ConversationRoleAgent}, + }) messages[0].Message = "modification" - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, ""), - }, c.Messages()) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "", Role: st.ConversationRoleAgent}, + }) }) t.Run("whitespace-padding", func(t *testing.T) { - c := newConversation() + c, _, _ := newConversation(context.Background(), t) for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} { err := c.Send(st.MessagePartText{Content: msg}) assert.ErrorIs(t, err, st.ErrMessageValidationWhitespace) @@ -173,335 +266,342 @@ func TestMessages(t *testing.T) { }) t.Run("no-change-no-message-update", func(t *testing.T) { - mClock := quartz.NewMock(t) - mClock.Set(now) - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.Clock = mClock - }) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + c, agent, mClock := newConversation(ctx, t) - c.Snapshot("1") + agent.setScreen("1") + advanceFor(ctx, t, mClock, interval) msgs := c.Messages() - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1"), - }, msgs) - mClock.Set(now.Add(1 * time.Second)) - c.Snapshot("1") + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "1", Role: st.ConversationRoleAgent}, + }) + + advanceFor(ctx, t, mClock, interval) assert.Equal(t, msgs, c.Messages()) }) t.Run("tracking messages", func(t *testing.T) { - agent := &testAgent{} - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.AgentIO = agent + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + c, agent, mClock := newConversation(ctx, t) + + // Agent message is recorded when the first snapshot is taken. + agent.setScreen("1") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "1", Role: st.ConversationRoleAgent}, + }) + + // Agent message is updated when the screen changes. + agent.setScreen("2") + advanceFor(ctx, t, mClock, interval) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "2", Role: st.ConversationRoleAgent}, + }) + + // Fill to stable so Send can proceed (screen is "2"). + agent.setScreen("2") + advanceFor(ctx, t, mClock, interval*threshold) + + // User message is recorded. + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "3"}) + + // After send, screen is dirty from writeStabilize. Set to "4" and stabilize. + agent.setScreen("4") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "2", Role: st.ConversationRoleAgent}, + {Id: 1, Message: "3", Role: st.ConversationRoleUser}, + {Id: 2, Message: "4", Role: st.ConversationRoleAgent}, + }) + + // Agent message is updated when the screen changes before a user message. + agent.setScreen("5") + advanceFor(ctx, t, mClock, interval*threshold) + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "6"}) + + agent.setScreen("7") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "2", Role: st.ConversationRoleAgent}, + {Id: 1, Message: "3", Role: st.ConversationRoleUser}, + {Id: 2, Message: "5", Role: st.ConversationRoleAgent}, + {Id: 3, Message: "6", Role: st.ConversationRoleUser}, + {Id: 4, Message: "7", Role: st.ConversationRoleAgent}, }) - // agent message is recorded when the first snapshot is added - c.Snapshot("1") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1"), - }, c.Messages()) - - // agent message is updated when the screen changes - c.Snapshot("2") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "2"), - }, c.Messages()) - - // user message is recorded - agent.screen = "2" - assert.NoError(t, sendMsg(c, "3")) - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "2"), - userMsg(1, "3"), - }, c.Messages()) - - // agent message is added after a user message - c.Snapshot("4") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "2"), - userMsg(1, "3"), - agentMsg(2, "4"), - }, c.Messages()) - - // agent message is updated when the screen changes before a user message - agent.screen = "5" - assert.NoError(t, sendMsg(c, "6")) - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "2"), - userMsg(1, "3"), - agentMsg(2, "5"), - userMsg(3, "6"), - }, c.Messages()) - - // conversation status is changing right after a user message - c.Snapshot("7") - c.Snapshot("7") - c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) - agent.screen = "7" - assert.NoError(t, sendMsg(c, "8")) - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "2"), - userMsg(1, "3"), - agentMsg(2, "5"), - userMsg(3, "6"), - agentMsg(4, "7"), - userMsg(5, "8"), - }, c.Messages()) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) - // conversation status is back to stable after a snapshot that - // doesn't change the screen - c.Snapshot("7") + // Send another message. + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "8"}) + + // After filling to stable, messages and status are correct. + agent.setScreen("7") + advanceFor(ctx, t, mClock, interval*threshold) assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("tracking messages overlap", func(t *testing.T) { - agent := &testAgent{} - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.AgentIO = agent + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + c, agent, mClock := newConversation(ctx, t) + + // Common overlap between screens is removed after a user message. + agent.setScreen("1") + advanceFor(ctx, t, mClock, interval*threshold) + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "2"}) + agent.setScreen("1\n3") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "1", Role: st.ConversationRoleAgent}, + {Id: 1, Message: "2", Role: st.ConversationRoleUser}, + {Id: 2, Message: "3", Role: st.ConversationRoleAgent}, }) - // common overlap between screens is removed after a user message - c.Snapshot("1") - agent.screen = "1" - assert.NoError(t, sendMsg(c, "2")) - c.Snapshot("1\n3") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1"), - userMsg(1, "2"), - agentMsg(2, "3"), - }, c.Messages()) - - agent.screen = "1\n3x" - assert.NoError(t, sendMsg(c, "4")) - c.Snapshot("1\n3x\n5") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1"), - userMsg(1, "2"), - agentMsg(2, "3x"), - userMsg(3, "4"), - agentMsg(4, "5"), - }, c.Messages()) + agent.setScreen("1\n3x") + advanceFor(ctx, t, mClock, interval*threshold) + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "4"}) + agent.setScreen("1\n3x\n5") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "1", Role: st.ConversationRoleAgent}, + {Id: 1, Message: "2", Role: st.ConversationRoleUser}, + {Id: 2, Message: "3x", Role: st.ConversationRoleAgent}, + {Id: 3, Message: "4", Role: st.ConversationRoleUser}, + {Id: 4, Message: "5", Role: st.ConversationRoleAgent}, + }) }) t.Run("format-message", func(t *testing.T) { - agent := &testAgent{} - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.AgentIO = agent + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + c, agent, mClock := newConversation(ctx, t, func(cfg *st.PTYConversationConfig) { cfg.FormatMessage = func(message string, userInput string) string { return message + " " + userInput } }) - agent.screen = "1" - assert.NoError(t, sendMsg(c, "2")) - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1 "), - userMsg(1, "2"), - }, c.Messages()) - agent.screen = "x" - c.Snapshot("x") - assert.Equal(t, []st.ConversationMessage{ - agentMsg(0, "1 "), - userMsg(1, "2"), - agentMsg(2, "x 2"), - }, c.Messages()) + + // Fill to stable with screen "1", then send. + agent.setScreen("1") + advanceFor(ctx, t, mClock, interval*threshold) + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "2"}) + + // After send, set screen to "x" and take snapshots for new agent message. + agent.setScreen("x") + advanceFor(ctx, t, mClock, interval*threshold) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "1 ", Role: st.ConversationRoleAgent}, + {Id: 1, Message: "2", Role: st.ConversationRoleUser}, + {Id: 2, Message: "x 2", Role: st.ConversationRoleAgent}, + }) }) - t.Run("format-message", func(t *testing.T) { - agent := &testAgent{} - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.AgentIO = agent + t.Run("format-message-initial", func(t *testing.T) { + c, _, _ := newConversation(context.Background(), t, func(cfg *st.PTYConversationConfig) { cfg.FormatMessage = func(message string, userInput string) string { return "formatted" } }) - assert.Equal(t, []st.ConversationMessage{ - { - Id: 0, - Message: "", - Role: st.ConversationRoleAgent, - Time: now, - }, - }, c.Messages()) + assertMessages(t, c, []st.ConversationMessage{ + {Id: 0, Message: "", Role: st.ConversationRoleAgent}, + }) }) t.Run("send-message-status-check", func(t *testing.T) { - c := newConversation(func(cfg *st.PTYConversationConfig) { - cfg.SkipSendMessageStatusCheck = false - cfg.SnapshotInterval = 1 * time.Second - cfg.ScreenStabilityLength = 2 * time.Second - cfg.AgentIO = &testAgent{} - }) - assert.ErrorIs(t, sendMsg(c, "1"), st.ErrMessageValidationChanging) - for range 3 { - c.Snapshot("1") + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + c, agent, mClock := newConversation(ctx, t) + + sendMsg := func(msg string) error { + return c.Send(st.MessagePartText{Content: msg}) } - assert.NoError(t, sendMsg(c, "4")) - c.Snapshot("2") - assert.ErrorIs(t, sendMsg(c, "5"), st.ErrMessageValidationChanging) + + // Status is initializing, send should fail. + assert.ErrorIs(t, sendMsg("1"), st.ErrMessageValidationChanging) + + // Fill to stable. + agent.setScreen("1") + advanceFor(ctx, t, mClock, interval*threshold) + assert.Equal(t, st.ConversationStatusStable, c.Status()) + + // Now send should succeed. + sendAndAdvance(ctx, t, c, mClock, st.MessagePartText{Content: "4"}) + + // After send, screen is dirty. Set to "2" (different from "1") so status is changing. + agent.setScreen("2") + advanceFor(ctx, t, mClock, interval) + assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.ErrorIs(t, sendMsg("5"), st.ErrMessageValidationChanging) }) t.Run("send-message-empty-message", func(t *testing.T) { - c := newConversation() - assert.ErrorIs(t, sendMsg(c, ""), st.ErrMessageValidationEmpty) + c, _, _ := newConversation(context.Background(), t) + assert.ErrorIs(t, c.Send(st.MessagePartText{Content: ""}), st.ErrMessageValidationEmpty) }) } func TestInitialPromptReadiness(t *testing.T) { - now := time.Now() + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) t.Run("agent not ready - status remains changing", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) mClock := quartz.NewMock(t) - mClock.Set(now) + agent := &testAgent{screen: "loading..."} cfg := st.PTYConversationConfig{ Clock: mClock, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, - AgentIO: &testAgent{screen: "loading..."}, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { return message == "ready" }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, + Logger: discardLogger, } - c := st.NewPTY(context.Background(), cfg, "initial prompt here") - // Fill buffer with stable snapshots, but agent is not ready - c.Snapshot("loading...") + c := st.NewPTY(ctx, cfg) + c.Start(ctx) - // Even though screen is stable, status should be changing because agent is not ready + // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). + advanceFor(ctx, t, mClock, 1*time.Second) + + // Even though screen is stable, status should be changing because + // the initial prompt is still in the outbound queue. assert.Equal(t, st.ConversationStatusChanging, c.Status()) - assert.False(t, c.ReadyForInitialPrompt) - assert.False(t, c.InitialPromptSent) }) - t.Run("agent becomes ready - status changes to stable", func(t *testing.T) { + t.Run("agent becomes ready - status stays changing until initial prompt sent", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) mClock := quartz.NewMock(t) - mClock.Set(now) + agent := &testAgent{screen: "loading..."} cfg := st.PTYConversationConfig{ Clock: mClock, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, - AgentIO: &testAgent{screen: "loading..."}, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { return message == "ready" }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, + Logger: discardLogger, } - c := st.NewPTY(context.Background(), cfg, "initial prompt here") - // Agent not ready initially - c.Snapshot("loading...") + c := st.NewPTY(ctx, cfg) + c.Start(ctx) + + // Agent not ready initially. + advanceFor(ctx, t, mClock, 1*time.Second) assert.Equal(t, st.ConversationStatusChanging, c.Status()) - // Agent becomes ready - c.Snapshot("ready") - assert.Equal(t, st.ConversationStatusStable, c.Status()) - assert.True(t, c.ReadyForInitialPrompt) - assert.False(t, c.InitialPromptSent) + // Agent becomes ready, but status stays "changing" because the + // initial prompt is still in the outbound queue. + agent.setScreen("ready") + advanceFor(ctx, t, mClock, 1*time.Second) + assert.Equal(t, st.ConversationStatusChanging, c.Status()) }) - t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) { + t.Run("initial prompt lifecycle - status stays changing until sent", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) mClock := quartz.NewMock(t) - mClock.Set(now) agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 0, - AgentIO: agent, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { return message == "ready" }, - SkipWritingMessage: true, - SkipSendMessageStatusCheck: true, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, + Logger: discardLogger, } - c := st.NewPTY(context.Background(), cfg, "initial prompt here") - - // Initial state: ReadyForInitialPrompt should be false - c.Snapshot("loading...") - assert.False(t, c.ReadyForInitialPrompt, "should start as false") - assert.False(t, c.InitialPromptSent) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) - // Agent becomes ready: ReadyForInitialPrompt should become true - agent.screen = "ready" - c.Snapshot("ready") - assert.Equal(t, st.ConversationStatusStable, c.Status()) - assert.True(t, c.ReadyForInitialPrompt, "should become true when ready") - assert.False(t, c.InitialPromptSent) + c := st.NewPTY(ctx, cfg) + c.Start(ctx) - // Send the initial prompt - assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) + // Status is "changing" while waiting for readiness. + advanceFor(ctx, t, mClock, 1*time.Second) + assert.Equal(t, st.ConversationStatusChanging, c.Status()) - // After sending initial prompt: ReadyForInitialPrompt should be set back to false - // (simulating what happens in the actual server code) - c.InitialPromptSent = true - c.ReadyForInitialPrompt = false + // Agent becomes ready. The readiness loop detects this, the snapshot + // loop sees queue + stable + ready and signals the send loop. + // writeStabilize runs with onWrite changing the screen, so it completes. + agent.setScreen("ready") + // Drive clock until the initial prompt is sent (queue drains). + advanceUntil(ctx, t, mClock, func() bool { + return len(c.Messages()) >= 2 + }) - // Verify final state - assert.False(t, c.ReadyForInitialPrompt, "should be false after initial prompt sent") - assert.True(t, c.InitialPromptSent) + // The initial prompt should have been sent. Set a clean screen and + // advance enough ticks for the snapshot loop to record it as an + // agent message and fill the stability buffer (threshold=1). + agent.setScreen("response") + advanceFor(ctx, t, mClock, 2*time.Second) + assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("no initial prompt - normal status logic applies", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) mClock := quartz.NewMock(t) - mClock.Set(now) + agent := &testAgent{screen: "loading..."} cfg := st.PTYConversationConfig{ Clock: mClock, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, - AgentIO: &testAgent{screen: "loading..."}, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { - return false // Agent never ready + return false }, + Logger: discardLogger, } - // Empty initial prompt means no need to wait for readiness - c := st.NewPTY(context.Background(), cfg, "") - c.Snapshot("loading...") + c := st.NewPTY(ctx, cfg) + c.Start(ctx) - // Status should be stable because no initial prompt to wait for + advanceFor(ctx, t, mClock, 1*time.Second) + + // Status should be stable because no initial prompt to wait for. assert.Equal(t, st.ConversationStatusStable, c.Status()) - assert.False(t, c.ReadyForInitialPrompt) - assert.True(t, c.InitialPromptSent) // Set to true when initial prompt is empty }) - t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) { + t.Run("no initial prompt configured - normal status logic applies", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) mClock := quartz.NewMock(t) - mClock.Set(now) agent := &testAgent{screen: "ready"} cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 0, - AgentIO: agent, - ReadyForInitialPrompt: func(message string) bool { - return message == "ready" - }, - SkipWritingMessage: true, - SkipSendMessageStatusCheck: true, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 2 * time.Second, // threshold = 3 + AgentIO: agent, + Logger: discardLogger, } - c := st.NewPTY(context.Background(), cfg, "initial prompt here") - // First, agent becomes ready - c.Snapshot("ready") - assert.Equal(t, st.ConversationStatusStable, c.Status()) - assert.True(t, c.ReadyForInitialPrompt) - assert.False(t, c.InitialPromptSent) + c := st.NewPTY(ctx, cfg) + c.Start(ctx) - // Send the initial prompt - agent.screen = "processing..." - assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) - - // Mark initial prompt as sent (simulating what the server does) - c.InitialPromptSent = true - c.ReadyForInitialPrompt = false + // Fill buffer to reach stability with "ready" screen. + agent.setScreen("ready") + advanceFor(ctx, t, mClock, 3*time.Second) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Now test that status logic works normally after initial prompt is sent - c.Snapshot("processing...") + // After screen changes, status becomes changing. + agent.setScreen("processing...") + advanceFor(ctx, t, mClock, 1*time.Second) + assert.Equal(t, st.ConversationStatusChanging, c.Status()) - // Status should be stable because initial prompt was already sent - // and the readiness check is bypassed + // After screen is stable again (3 identical snapshots), status becomes stable. + advanceFor(ctx, t, mClock, 1*time.Second) + advanceFor(ctx, t, mClock, 1*time.Second) assert.Equal(t, st.ConversationStatusStable, c.Status()) - assert.False(t, c.ReadyForInitialPrompt) - assert.True(t, c.InitialPromptSent) }) } diff --git a/lib/util/util.go b/lib/util/util.go index bbd70d56..4c8be4bb 100644 --- a/lib/util/util.go +++ b/lib/util/util.go @@ -23,6 +23,8 @@ var WaitTimedOut = xerrors.New("timeout waiting for condition") // WaitFor waits for a condition to be true or the timeout to expire. // It will wait for the condition to be true with exponential backoff. +// A single sleep timer is reused across iterations via Reset so that +// mock-clock tests always have a pending timer to advance. func WaitFor(ctx context.Context, timeout WaitTimeout, condition func() (bool, error)) error { clock := timeout.Clock if clock == nil { @@ -41,53 +43,40 @@ func WaitFor(ctx context.Context, timeout WaitTimeout, condition func() (bool, e if timeoutDuration == 0 { timeoutDuration = 10 * time.Second } - timeoutTimer := clock.NewTimer(timeoutDuration) - defer timeoutTimer.Stop() - if minInterval > maxInterval { return xerrors.Errorf("minInterval is greater than maxInterval") } + timeoutTimer := clock.NewTimer(timeoutDuration) + defer timeoutTimer.Stop() + interval := minInterval - if timeout.InitialWait { - initialTimer := clock.NewTimer(interval) - defer initialTimer.Stop() - select { - case <-initialTimer.C: - case <-ctx.Done(): - initialTimer.Stop() - return ctx.Err() - case <-timeoutTimer.C: - initialTimer.Stop() - return WaitTimedOut - } - } + sleepTimer := clock.NewTimer(interval) + defer sleepTimer.Stop() + + waitForTimer := timeout.InitialWait for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-timeoutTimer.C: - return WaitTimedOut - default: - ok, err := condition() - if err != nil { - return err - } - if ok { - return nil - } - sleepTimer := clock.NewTimer(interval) + if waitForTimer { select { case <-sleepTimer.C: case <-ctx.Done(): - sleepTimer.Stop() return ctx.Err() case <-timeoutTimer.C: - sleepTimer.Stop() return WaitTimedOut } - interval = min(interval*2, maxInterval) } + waitForTimer = true + + ok, err := condition() + if err != nil { + return err + } + if ok { + return nil + } + + interval = min(interval*2, maxInterval) + sleepTimer.Reset(interval) } }