diff --git a/go-opencode/go.mod b/go-opencode/go.mod index c862d590ff1..98bc53fd4c2 100644 --- a/go-opencode/go.mod +++ b/go-opencode/go.mod @@ -1,8 +1,6 @@ module github.com/opencode-ai/opencode -go 1.23.0 - -toolchain go1.24.7 +go 1.25 require ( // Eino LLM Framework @@ -32,6 +30,7 @@ require ( github.com/mark3labs/mcp-go v0.43.1 github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/rs/zerolog v1.34.0 + github.com/sergi/go-diff v1.4.0 github.com/spf13/cobra v1.10.1 github.com/sst/opencode-sdk-go v0.0.0-00010101000000-000000000000 github.com/stretchr/testify v1.11.1 @@ -117,7 +116,7 @@ require ( golang.org/x/text v0.28.0 // indirect golang.org/x/tools v0.36.0 // indirect google.golang.org/protobuf v1.36.8 // indirect - gopkg.in/yaml.v2 v2.2.8 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) replace github.com/sst/opencode-sdk-go => ../packages/sdk/go diff --git a/go-opencode/go.sum b/go-opencode/go.sum index f13795d85e2..3f1d62a42f8 100644 --- a/go-opencode/go.sum +++ b/go-opencode/go.sum @@ -233,6 +233,8 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= +github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= @@ -254,6 +256,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -371,8 +374,9 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go-opencode/internal/config/config.go b/go-opencode/internal/config/config.go index 4a9410b77ec..8635335c12e 100644 --- a/go-opencode/internal/config/config.go +++ b/go-opencode/internal/config/config.go @@ -13,22 +13,31 @@ import ( ) // Load loads configuration from multiple sources (priority order): -// 1. Global config (~/.opencode/ - TypeScript compatible) -// 2. Global config (~/.config/opencode/ - XDG compatible) -// 3. Project config (.opencode/) -// 4. OPENCODE_CONFIG file -// 5. OPENCODE_CONFIG_CONTENT inline JSON -// 6. Environment variables +// 1. Global config (~/.opencode/ - TypeScript compatible) +// 2. Global config (~/.config/opencode/ - XDG compatible) +// 3. Project configs discovered while walking up from the working directory +// (opencode.json/opencode.jsonc and .opencode/opencode.json/opencode.jsonc) +// 4. OPENCODE_CONFIG file +// 5. OPENCODE_CONFIG_CONTENT inline JSON +// 6. Environment variables func Load(directory string) (*types.Config, error) { config := &types.Config{ Provider: make(map[string]types.ProviderConfig), Agent: make(map[string]types.AgentConfig), + Keybinds: types.DefaultKeybinds(), } // Track loaded files to avoid duplicates loaded := make(map[string]bool) var loadedFiles []string + // Normalize working directory for deterministic traversal + if directory != "" { + if absDir, err := filepath.Abs(directory); err == nil { + directory = absDir + } + } + loadOnce := func(path string, baseDir string) { absPath, err := filepath.Abs(path) if err != nil { @@ -60,13 +69,25 @@ func Load(directory string) (*types.Config, error) { loadOnce(filepath.Join(globalPath, "opencode.json"), globalPath) loadOnce(filepath.Join(globalPath, "opencode.jsonc"), globalPath) - // 3. Project config + var searchDirs []string if directory != "" { - projectConfigDir := filepath.Join(directory, ".opencode") - loadOnce(filepath.Join(directory, "opencode.json"), directory) - loadOnce(filepath.Join(directory, "opencode.jsonc"), directory) - loadOnce(filepath.Join(projectConfigDir, "opencode.json"), projectConfigDir) - loadOnce(filepath.Join(projectConfigDir, "opencode.jsonc"), projectConfigDir) + searchDirs = walkUpDirectories(directory) + } + + // 3. Project config (root -> leaf for top-level files) + if len(searchDirs) > 0 { + for i := len(searchDirs) - 1; i >= 0; i-- { + dir := searchDirs[i] + loadOnce(filepath.Join(dir, "opencode.jsonc"), dir) + loadOnce(filepath.Join(dir, "opencode.json"), dir) + } + + // .opencode directories (leaf -> root, matching TS loader) + for _, dir := range searchDirs { + projectConfigDir := filepath.Join(dir, ".opencode") + loadOnce(filepath.Join(projectConfigDir, "opencode.jsonc"), projectConfigDir) + loadOnce(filepath.Join(projectConfigDir, "opencode.json"), projectConfigDir) + } } // 4. OPENCODE_CONFIG file override @@ -228,6 +249,9 @@ func mergeConfig(target, source *types.Config) { target.Share = source.Share } + // Merge keybinds (per-field override) + target.Keybinds = types.MergeKeybinds(target.Keybinds, source.Keybinds) + // Merge tools if source.Tools != nil { if target.Tools == nil { @@ -391,3 +415,33 @@ func GetConfigDir() string { // Fall back to XDG location return GetPaths().Config } + +// walkUpDirectories returns a slice of directories from the starting directory +// up to either the git root (if found) or filesystem root. The starting directory +// is always the first element in the returned slice. +func walkUpDirectories(start string) []string { + var dirs []string + current := start + + for { + dirs = append(dirs, current) + + if isGitRoot(current) { + break + } + + parent := filepath.Dir(current) + if parent == current { + break + } + current = parent + } + + return dirs +} + +// isGitRoot checks whether the provided directory contains a .git entry. +func isGitRoot(dir string) bool { + _, err := os.Stat(filepath.Join(dir, ".git")) + return err == nil +} diff --git a/go-opencode/internal/config/config_test.go b/go-opencode/internal/config/config_test.go index 932069c062b..2e35ff93c8f 100644 --- a/go-opencode/internal/config/config_test.go +++ b/go-opencode/internal/config/config_test.go @@ -83,6 +83,32 @@ func TestLoadTypeScriptConfig(t *testing.T) { assert.True(t, coder.Tools["edit"]) } +func TestDefaultKeybindsApplied(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "opencode-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Isolate config paths to avoid picking up user configs + oldHome := os.Getenv("HOME") + oldXdgConfig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("HOME", tmpDir) + os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, ".config")) + defer func() { + os.Setenv("HOME", oldHome) + if oldXdgConfig == "" { + os.Unsetenv("XDG_CONFIG_HOME") + } else { + os.Setenv("XDG_CONFIG_HOME", oldXdgConfig) + } + }() + + cfg, err := Load("") + require.NoError(t, err) + + assert.Equal(t, "ctrl+x", cfg.Keybinds.Leader) + assert.Equal(t, "escape", cfg.Keybinds.SessionInterrupt) +} + func TestJSONCComments(t *testing.T) { // Create a temporary directory tmpDir, err := os.MkdirTemp("", "opencode-test-*") @@ -262,6 +288,54 @@ func TestConfigMerge(t *testing.T) { assert.True(t, cfg.Agent["coder"].Tools["edit"]) } +func TestLoadConfigFromParentDirectory(t *testing.T) { + root := t.TempDir() + + // Isolate HOME/XDG to avoid picking up user config + oldHome := os.Getenv("HOME") + oldXdg := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("HOME", root) + os.Setenv("XDG_CONFIG_HOME", filepath.Join(root, ".config")) + defer func() { + os.Setenv("HOME", oldHome) + if oldXdg == "" { + os.Unsetenv("XDG_CONFIG_HOME") + } else { + os.Setenv("XDG_CONFIG_HOME", oldXdg) + } + }() + + // Parent-level configs + require.NoError(t, os.WriteFile(filepath.Join(root, "opencode.json"), []byte(`{"small_model":"parent-small"}`), 0644)) + + parentConfigDir := filepath.Join(root, ".opencode") + require.NoError(t, os.MkdirAll(parentConfigDir, 0755)) + parentConfig := `{ + "model": "parent-model", + "mcp": { + "parent-server": { + "type": "remote", + "url": "https://parent.example.com/mcp" + } + } + }` + require.NoError(t, os.WriteFile(filepath.Join(parentConfigDir, "opencode.json"), []byte(parentConfig), 0644)) + + workDir := filepath.Join(root, "nested", "project") + require.NoError(t, os.MkdirAll(workDir, 0755)) + + cfg, err := Load(workDir) + require.NoError(t, err) + + assert.Equal(t, "parent-model", cfg.Model) + assert.Equal(t, "parent-small", cfg.SmallModel) + require.NotNil(t, cfg.MCP) + remote, ok := cfg.MCP["parent-server"] + require.True(t, ok) + assert.Equal(t, "remote", remote.Type) + assert.Equal(t, "https://parent.example.com/mcp", remote.URL) +} + func TestEnvVarOverride(t *testing.T) { // Set test environment variables os.Setenv("OPENCODE_MODEL", "env-model") diff --git a/go-opencode/internal/event/bus.go b/go-opencode/internal/event/bus.go index ccb72209527..c0570851909 100644 --- a/go-opencode/internal/event/bus.go +++ b/go-opencode/internal/event/bus.go @@ -15,11 +15,14 @@ import ( type EventType string const ( - SessionCreated EventType = "session.created" - SessionUpdated EventType = "session.updated" - SessionDeleted EventType = "session.deleted" - SessionIdle EventType = "session.idle" - SessionError EventType = "session.error" + SessionCreated EventType = "session.created" + SessionUpdated EventType = "session.updated" + SessionDeleted EventType = "session.deleted" + SessionIdle EventType = "session.idle" + SessionStatus EventType = "session.status" + SessionDiff EventType = "session.diff" + SessionError EventType = "session.error" + SessionCompacted EventType = "session.compacted" MessageCreated EventType = "message.created" MessageUpdated EventType = "message.updated" MessageRemoved EventType = "message.removed" diff --git a/go-opencode/internal/event/types.go b/go-opencode/internal/event/types.go index cdef5bf4bd7..3cc4d52e140 100644 --- a/go-opencode/internal/event/types.go +++ b/go-opencode/internal/event/types.go @@ -25,12 +25,36 @@ type SessionIdleData struct { SessionID string `json:"sessionID"` } +// SessionStatusData is the data for session.status events. +// SDK compatible: uses "sessionID" and "status" fields. +type SessionStatusData struct { + SessionID string `json:"sessionID"` + Status SessionStatusInfo `json:"status"` +} + +// SessionStatusInfo represents the current status of a session. +type SessionStatusInfo struct { + Type string `json:"type"` // "busy" | "idle" +} + +// SessionDiffData is the data for session.diff events. +// SDK compatible: uses "sessionID" and "diff" fields. +type SessionDiffData struct { + SessionID string `json:"sessionID"` + Diff []types.FileDiff `json:"diff"` +} + // SessionErrorData is the data for session.error events. type SessionErrorData struct { SessionID string `json:"sessionID,omitempty"` Error *types.MessageError `json:"error,omitempty"` } +// SessionCompactedData is the data for session.compacted events. +type SessionCompactedData struct { + SessionID string `json:"sessionID"` +} + // MessageCreatedData is the data for message.created events. // SDK compatible: uses "info" field for message object. type MessageCreatedData struct { diff --git a/go-opencode/internal/mcp/client.go b/go-opencode/internal/mcp/client.go index 804428d710c..257cf90942a 100644 --- a/go-opencode/internal/mcp/client.go +++ b/go-opencode/internal/mcp/client.go @@ -88,31 +88,47 @@ func (c *Client) connectServer(ctx context.Context, name string, config *Config) timeout = 5 * time.Second } - var transport sdkmcp.Transport - var connectCtx context.Context - var connectCancel context.CancelFunc + server := &mcpServer{ + name: name, + config: config, + status: StatusConnecting, + } switch config.Type { case TransportTypeRemote: - // For SSE transport, use a long-lived context (background) because - // the SSE connection stays open for the lifetime of the session. - // The HTTP client timeout handles the initial connection timeout. - connectCtx = context.Background() - connectCancel = func() {} // no-op - - httpClient := &http.Client{Timeout: timeout} - transport = &sdkmcp.SSEClientTransport{ - Endpoint: config.URL, - HTTPClient: httpClient, + httpClient := httpClientWithHeaders(nil, config.Headers) + transports := []struct { + name string + transport sdkmcp.Transport + }{ + {name: "streamable", transport: &sdkmcp.StreamableClientTransport{Endpoint: config.URL, HTTPClient: httpClient}}, + {name: "sse", transport: &sdkmcp.SSEClientTransport{Endpoint: config.URL, HTTPClient: httpClient}}, } + var lastErr error + for _, candidate := range transports { + session, err := c.connectWithTransport(context.Background(), candidate.transport, timeout, server) + if err != nil { + lastErr = fmt.Errorf("%s transport: %w", candidate.name, err) + continue + } + server.session = session + server.status = StatusConnected + return server, nil + } + + if lastErr == nil { + lastErr = fmt.Errorf("failed to connect: unknown error") + } + return nil, lastErr + case TransportTypeLocal, TransportTypeStdio: if len(config.Command) == 0 { return nil, fmt.Errorf("empty command") } - // For stdio transport, we can use a timeout context for the initial connection - connectCtx, connectCancel = context.WithTimeout(ctx, timeout) + connectCtx, connectCancel := context.WithTimeout(ctx, timeout) + defer connectCancel() cmd := exec.Command(config.Command[0], config.Command[1:]...) @@ -122,50 +138,82 @@ func (c *Client) connectServer(ctx context.Context, name string, config *Config) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } - transport = &sdkmcp.CommandTransport{Command: cmd} + session, err := c.connectWithTransport(connectCtx, &sdkmcp.CommandTransport{Command: cmd}, timeout, server) + if err != nil { + return nil, err + } + server.session = session + server.status = StatusConnected + return server, nil default: return nil, fmt.Errorf("unknown transport type: %s", config.Type) } +} - server := &mcpServer{ - name: name, - config: config, - status: StatusConnecting, - } - - // Connect using the SDK client - session, err := c.sdkClient.Connect(connectCtx, transport, nil) +func (c *Client) connectWithTransport(ctx context.Context, transport sdkmcp.Transport, timeout time.Duration, server *mcpServer) (*sdkmcp.ClientSession, error) { + session, err := c.sdkClient.Connect(ctx, transport, nil) if err != nil { - connectCancel() return nil, fmt.Errorf("failed to connect: %w", err) } server.session = session - // Get server info from initialization result - initResult := session.InitializeResult() - if initResult != nil { + // Capture server info from initialization result if available + if initResult := session.InitializeResult(); initResult != nil { server.serverInfo = &ServerInfo{ Name: initResult.ServerInfo.Name, Version: initResult.ServerInfo.Version, } } - // List tools - use a separate context for this operation listCtx, listCancel := context.WithTimeout(context.Background(), timeout) defer listCancel() if err := server.listTools(listCtx); err != nil { - // Non-fatal, tools might not be supported - server.tools = []Tool{} + session.Close() + return nil, fmt.Errorf("failed to list tools: %w", err) } - // For stdio transport, cancel the connect context now that setup is complete - // For SSE, connectCancel is a no-op since we use background context - connectCancel() + return session, nil +} + +func httpClientWithHeaders(base *http.Client, headers map[string]string) *http.Client { + if base == nil { + base = &http.Client{} + } - server.status = StatusConnected - return server, nil + // Copy to avoid mutating caller-provided client + client := *base + client.Timeout = 0 // no global timeout; rely on per-request contexts + + if len(headers) == 0 { + return &client + } + + transport := client.Transport + if transport == nil { + transport = http.DefaultTransport + } + + client.Transport = &headerRoundTripper{ + headers: headers, + next: transport, + } + + return &client +} + +type headerRoundTripper struct { + headers map[string]string + next http.RoundTripper +} + +func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + cloned := req.Clone(req.Context()) + for k, v := range h.headers { + cloned.Header.Set(k, v) + } + return h.next.RoundTrip(cloned) } // listTools lists available tools from the server using the SDK. diff --git a/go-opencode/internal/mcp/mcp_test.go b/go-opencode/internal/mcp/mcp_test.go index b03a74bcb3a..d54c1cb3934 100644 --- a/go-opencode/internal/mcp/mcp_test.go +++ b/go-opencode/internal/mcp/mcp_test.go @@ -2,9 +2,13 @@ package mcp import ( "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewClient(t *testing.T) { @@ -199,6 +203,28 @@ func TestServerInfo(t *testing.T) { assert.Equal(t, "1.0.0", info.Version) } +func TestHTTPClientWithHeaders(t *testing.T) { + headerCh := make(chan string, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerCh <- r.Header.Get("X-Test") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := httpClientWithHeaders(nil, map[string]string{"X-Test": "ok"}) + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + resp.Body.Close() + + assert.Equal(t, "ok", <-headerCh) + assert.Equal(t, time.Duration(0), client.Timeout) +} + func TestContent(t *testing.T) { textContent := Content{Type: "text", Text: "Hello"} assert.Equal(t, "text", textContent.Type) diff --git a/go-opencode/internal/server/handlers_config.go b/go-opencode/internal/server/handlers_config.go index 85b39632e02..e5f96c3497b 100644 --- a/go-opencode/internal/server/handlers_config.go +++ b/go-opencode/internal/server/handlers_config.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "os" + "strings" "github.com/go-chi/chi/v5" @@ -14,6 +15,9 @@ import ( // getConfig handles GET /config func (s *Server) getConfig(w http.ResponseWriter, r *http.Request) { + if s.appConfig != nil { + s.appConfig.Keybinds = types.MergeKeybinds(types.DefaultKeybinds(), s.appConfig.Keybinds) + } writeJSON(w, http.StatusOK, s.appConfig) } @@ -37,19 +41,37 @@ func (s *Server) updateConfig(w http.ResponseWriter, r *http.Request) { } // ProviderModel represents a model in models.dev format for TUI compatibility. +// SDK compatible: uses "capabilities" with nested boolean structure to match TypeScript. type ProviderModel struct { - ID string `json:"id"` - Name string `json:"name"` - ReleaseDate string `json:"release_date"` - Attachment bool `json:"attachment"` - Reasoning bool `json:"reasoning"` - Temperature bool `json:"temperature"` - ToolCall bool `json:"tool_call"` - Cost ModelCost `json:"cost"` - Limit ModelLimit `json:"limit"` - Options map[string]any `json:"options"` - Modalities *ModelModalities `json:"modalities,omitempty"` - Status string `json:"status,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + ReleaseDate string `json:"release_date"` + Capabilities *ModelCapabilities `json:"capabilities"` + Cost ModelCost `json:"cost"` + Limit ModelLimit `json:"limit"` + Options map[string]any `json:"options"` + Status string `json:"status,omitempty"` +} + +// ModelCapabilities represents model capabilities and modalities. +// SDK compatible: matches TypeScript Model.capabilities structure. +type ModelCapabilities struct { + Temperature bool `json:"temperature"` + Reasoning bool `json:"reasoning"` + Attachment bool `json:"attachment"` + ToolCall bool `json:"toolcall"` + Input ModalityCapabilities `json:"input"` + Output ModalityCapabilities `json:"output"` +} + +// ModalityCapabilities represents input/output modality capabilities. +// SDK compatible: matches TypeScript input/output capability structure. +type ModalityCapabilities struct { + Text bool `json:"text"` + Audio bool `json:"audio"` + Image bool `json:"image"` + Video bool `json:"video"` + PDF bool `json:"pdf"` } // ModelCost represents model pricing. @@ -66,12 +88,6 @@ type ModelLimit struct { Output int `json:"output"` } -// ModelModalities represents model input/output modalities. -type ModelModalities struct { - Input []string `json:"input"` - Output []string `json:"output"` -} - // ProviderInfo represents provider information in models.dev format for TUI compatibility. type ProviderInfo struct { ID string `json:"id"` @@ -102,40 +118,49 @@ func getDefaultProviders() []ProviderInfo { ID: "claude-sonnet-4-20250514", Name: "Claude Sonnet 4", ReleaseDate: "2025-05-14", - Attachment: true, - Reasoning: false, - Temperature: true, - ToolCall: true, - Cost: ModelCost{Input: 3.0, Output: 15.0, CacheRead: 0.3, CacheWrite: 3.75}, - Limit: ModelLimit{Context: 200000, Output: 64000}, - Options: map[string]any{}, - Modalities: &ModelModalities{Input: []string{"text", "image", "pdf"}, Output: []string{"text"}}, + Capabilities: &ModelCapabilities{ + Temperature: true, + Reasoning: false, + Attachment: true, + ToolCall: true, + Input: ModalityCapabilities{Text: true, Audio: false, Image: true, Video: false, PDF: true}, + Output: ModalityCapabilities{Text: true, Audio: false, Image: false, Video: false, PDF: false}, + }, + Cost: ModelCost{Input: 3.0, Output: 15.0, CacheRead: 0.3, CacheWrite: 3.75}, + Limit: ModelLimit{Context: 200000, Output: 64000}, + Options: map[string]any{}, }, "claude-opus-4-20250514": { ID: "claude-opus-4-20250514", Name: "Claude Opus 4", ReleaseDate: "2025-05-14", - Attachment: true, - Reasoning: false, - Temperature: true, - ToolCall: true, - Cost: ModelCost{Input: 15.0, Output: 75.0, CacheRead: 1.5, CacheWrite: 18.75}, - Limit: ModelLimit{Context: 200000, Output: 32000}, - Options: map[string]any{}, - Modalities: &ModelModalities{Input: []string{"text", "image", "pdf"}, Output: []string{"text"}}, + Capabilities: &ModelCapabilities{ + Temperature: true, + Reasoning: false, + Attachment: true, + ToolCall: true, + Input: ModalityCapabilities{Text: true, Audio: false, Image: true, Video: false, PDF: true}, + Output: ModalityCapabilities{Text: true, Audio: false, Image: false, Video: false, PDF: false}, + }, + Cost: ModelCost{Input: 15.0, Output: 75.0, CacheRead: 1.5, CacheWrite: 18.75}, + Limit: ModelLimit{Context: 200000, Output: 32000}, + Options: map[string]any{}, }, "claude-3-5-haiku-20241022": { ID: "claude-3-5-haiku-20241022", Name: "Claude 3.5 Haiku", ReleaseDate: "2024-10-22", - Attachment: true, - Reasoning: false, - Temperature: true, - ToolCall: true, - Cost: ModelCost{Input: 0.8, Output: 4.0, CacheRead: 0.08, CacheWrite: 1.0}, - Limit: ModelLimit{Context: 200000, Output: 8192}, - Options: map[string]any{}, - Modalities: &ModelModalities{Input: []string{"text", "image", "pdf"}, Output: []string{"text"}}, + Capabilities: &ModelCapabilities{ + Temperature: true, + Reasoning: false, + Attachment: true, + ToolCall: true, + Input: ModalityCapabilities{Text: true, Audio: false, Image: true, Video: false, PDF: true}, + Output: ModalityCapabilities{Text: true, Audio: false, Image: false, Video: false, PDF: false}, + }, + Cost: ModelCost{Input: 0.8, Output: 4.0, CacheRead: 0.08, CacheWrite: 1.0}, + Limit: ModelLimit{Context: 200000, Output: 8192}, + Options: map[string]any{}, }, }, }, @@ -149,27 +174,33 @@ func getDefaultProviders() []ProviderInfo { ID: "gpt-4o", Name: "GPT-4o", ReleaseDate: "2024-05-13", - Attachment: true, - Reasoning: false, - Temperature: true, - ToolCall: true, - Cost: ModelCost{Input: 2.5, Output: 10.0}, - Limit: ModelLimit{Context: 128000, Output: 16384}, - Options: map[string]any{}, - Modalities: &ModelModalities{Input: []string{"text", "image"}, Output: []string{"text"}}, + Capabilities: &ModelCapabilities{ + Temperature: true, + Reasoning: false, + Attachment: true, + ToolCall: true, + Input: ModalityCapabilities{Text: true, Audio: false, Image: true, Video: false, PDF: false}, + Output: ModalityCapabilities{Text: true, Audio: false, Image: false, Video: false, PDF: false}, + }, + Cost: ModelCost{Input: 2.5, Output: 10.0}, + Limit: ModelLimit{Context: 128000, Output: 16384}, + Options: map[string]any{}, }, "gpt-4o-mini": { ID: "gpt-4o-mini", Name: "GPT-4o Mini", ReleaseDate: "2024-07-18", - Attachment: true, - Reasoning: false, - Temperature: true, - ToolCall: true, - Cost: ModelCost{Input: 0.15, Output: 0.6}, - Limit: ModelLimit{Context: 128000, Output: 16384}, - Options: map[string]any{}, - Modalities: &ModelModalities{Input: []string{"text", "image"}, Output: []string{"text"}}, + Capabilities: &ModelCapabilities{ + Temperature: true, + Reasoning: false, + Attachment: true, + ToolCall: true, + Input: ModalityCapabilities{Text: true, Audio: false, Image: true, Video: false, PDF: false}, + Output: ModalityCapabilities{Text: true, Audio: false, Image: false, Video: false, PDF: false}, + }, + Cost: ModelCost{Input: 0.15, Output: 0.6}, + Limit: ModelLimit{Context: 128000, Output: 16384}, + Options: map[string]any{}, }, }, }, @@ -480,15 +511,234 @@ func (s *Server) readMCPResource(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, result) } +// AgentInfo represents agent information returned by the /agent endpoint. +// SDK compatible: matches TypeScript Agent.Info structure. +type AgentInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Mode string `json:"mode"` + BuiltIn bool `json:"builtIn"` + Prompt string `json:"prompt,omitempty"` + Tools map[string]bool `json:"tools"` + Options map[string]any `json:"options"` + Permission AgentPermissionInfo `json:"permission"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + Model *AgentModelRef `json:"model,omitempty"` + Color string `json:"color,omitempty"` +} + +// AgentPermissionInfo represents agent permission settings. +type AgentPermissionInfo struct { + Edit string `json:"edit,omitempty"` + Bash map[string]string `json:"bash,omitempty"` + WebFetch string `json:"webfetch,omitempty"` + ExternalDir string `json:"external_directory,omitempty"` + DoomLoop string `json:"doom_loop,omitempty"` +} + +// AgentModelRef references a model for an agent. +type AgentModelRef struct { + ProviderID string `json:"providerID"` + ModelID string `json:"modelID"` +} + // listAgents handles GET /agent +// Returns full agent objects matching TypeScript Agent.Info structure. func (s *Server) listAgents(w http.ResponseWriter, r *http.Request) { - agents := []map[string]any{ - {"id": "coder", "name": "Coder", "description": "General coding assistant"}, - {"id": "build", "name": "Build", "description": "Build and test assistant"}, + // Start with built-in agents + agents := getBuiltInAgents() + + // Merge with config agents + if s.appConfig != nil && s.appConfig.Agent != nil { + for name, cfg := range s.appConfig.Agent { + // Find existing or create new + var agent *AgentInfo + for i := range agents { + if agents[i].Name == name { + agent = &agents[i] + break + } + } + + if agent == nil { + // New custom agent + newAgent := AgentInfo{ + Name: name, + Mode: "all", + BuiltIn: false, + Tools: make(map[string]bool), + Options: make(map[string]any), + Permission: AgentPermissionInfo{ + Edit: "allow", + Bash: map[string]string{"*": "allow"}, + WebFetch: "allow", + DoomLoop: "ask", + ExternalDir: "ask", + }, + } + agents = append(agents, newAgent) + agent = &agents[len(agents)-1] + } + + // Apply config overrides + if cfg.Description != "" { + agent.Description = cfg.Description + } + if cfg.Prompt != "" { + agent.Prompt = cfg.Prompt + } + if cfg.Mode != "" { + agent.Mode = cfg.Mode + } + if cfg.Temperature != nil { + agent.Temperature = *cfg.Temperature + } + if cfg.TopP != nil { + agent.TopP = *cfg.TopP + } + if cfg.Color != "" { + agent.Color = cfg.Color + } + if cfg.Model != "" { + // Parse model string "provider/model" + parts := strings.SplitN(cfg.Model, "/", 2) + if len(parts) == 2 { + agent.Model = &AgentModelRef{ + ProviderID: parts[0], + ModelID: parts[1], + } + } + } + if cfg.Tools != nil { + for k, v := range cfg.Tools { + agent.Tools[k] = v + } + } + agent.BuiltIn = false // Mark as customized + } } + writeJSON(w, http.StatusOK, agents) } +// getBuiltInAgents returns the default built-in agents. +func getBuiltInAgents() []AgentInfo { + defaultPermission := AgentPermissionInfo{ + Edit: "allow", + Bash: map[string]string{"*": "allow"}, + WebFetch: "allow", + DoomLoop: "ask", + ExternalDir: "ask", + } + + planPermission := AgentPermissionInfo{ + Edit: "deny", + Bash: map[string]string{ + "cut*": "allow", + "diff*": "allow", + "du*": "allow", + "file *": "allow", + "find * -delete*": "ask", + "find * -exec*": "ask", + "find * -fprint*": "ask", + "find * -fls*": "ask", + "find * -fprintf*": "ask", + "find * -ok*": "ask", + "find *": "allow", + "git diff*": "allow", + "git log*": "allow", + "git show*": "allow", + "git status*": "allow", + "git branch": "allow", + "git branch -v": "allow", + "grep*": "allow", + "head*": "allow", + "less*": "allow", + "ls*": "allow", + "more*": "allow", + "pwd*": "allow", + "rg*": "allow", + "sort --output=*": "ask", + "sort -o *": "ask", + "sort*": "allow", + "stat*": "allow", + "tail*": "allow", + "tree -o *": "ask", + "tree*": "allow", + "uniq*": "allow", + "wc*": "allow", + "whereis*": "allow", + "which*": "allow", + "*": "ask", + }, + WebFetch: "allow", + } + + return []AgentInfo{ + { + Name: "general", + Description: "General-purpose agent for researching complex questions and executing multi-step tasks. Use this agent to execute multiple units of work in parallel.", + Mode: "subagent", + BuiltIn: true, + Tools: map[string]bool{ + "todoread": false, + "todowrite": false, + }, + Options: map[string]any{}, + Permission: defaultPermission, + }, + { + Name: "explore", + Description: `Fast agent specialized for exploring codebases. Use this when you need to quickly find files by patterns (eg. "src/components/**/*.tsx"), search code for keywords (eg. "API endpoints"), or answer questions about the codebase (eg. "how do API endpoints work?"). When calling this agent, specify the desired thoroughness level: "quick" for basic searches, "medium" for moderate exploration, or "very thorough" for comprehensive analysis across multiple locations and naming conventions.`, + Mode: "subagent", + BuiltIn: true, + Tools: map[string]bool{ + "todoread": false, + "todowrite": false, + "edit": false, + "write": false, + }, + Options: map[string]any{}, + Permission: defaultPermission, + Prompt: `You are a file search specialist. You excel at thoroughly navigating and exploring codebases. + +Your strengths: +- Rapidly finding files using glob patterns +- Searching code and text with powerful regex patterns +- Reading and analyzing file contents + +Guidelines: +- Use Glob for broad file pattern matching +- Use Grep for searching file contents with regex +- Use Read when you know the specific file path you need to read +- Use Bash for file operations like copying, moving, or listing directory contents +- Adapt your search approach based on the thoroughness level specified by the caller +- Return file paths as absolute paths in your final response +- For clear communication, avoid using emojis +- Do not create any files, or run bash commands that modify the user's system state in any way + +Complete the user's search request efficiently and report your findings clearly.`, + }, + { + Name: "build", + Mode: "primary", + BuiltIn: true, + Tools: map[string]bool{}, + Options: map[string]any{}, + Permission: defaultPermission, + }, + { + Name: "plan", + Mode: "primary", + BuiltIn: true, + Tools: map[string]bool{}, + Options: map[string]any{}, + Permission: planPermission, + }, + } +} + // getFormatterStatus handles GET /formatter func (s *Server) getFormatterStatus(w http.ResponseWriter, r *http.Request) { if s.formatterManager == nil { @@ -536,28 +786,36 @@ func (s *Server) formatFile(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "Either 'path' or 'paths' is required") } +// CommandInfo represents command information returned by the /command endpoint. +// SDK compatible: matches TypeScript Command.Info structure. +type CommandInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Template string `json:"template"` + Agent string `json:"agent,omitempty"` + Model string `json:"model,omitempty"` + Subtask bool `json:"subtask,omitempty"` +} + // listCommands handles GET /command +// Returns full command objects matching TypeScript Command.Info structure. func (s *Server) listCommands(w http.ResponseWriter, r *http.Request) { - // Start with builtin commands - commands := make([]map[string]any, 0) - for _, cmd := range command.BuiltinCommands() { - commands = append(commands, map[string]any{ - "name": cmd.Name, - "description": cmd.Description, - "source": cmd.Source, - }) - } + commands := make([]CommandInfo, 0) + + // Add built-in commands with templates + builtinCommands := getBuiltInCommands(s.config.Directory) + commands = append(commands, builtinCommands...) - // Add custom commands from executor + // Add custom commands from executor (config and file-based) if s.commandExecutor != nil { for _, cmd := range s.commandExecutor.List() { - commands = append(commands, map[string]any{ - "name": cmd.Name, - "description": cmd.Description, - "source": cmd.Source, - "agent": cmd.Agent, - "model": cmd.Model, - "subtask": cmd.Subtask, + commands = append(commands, CommandInfo{ + Name: cmd.Name, + Description: cmd.Description, + Template: cmd.Template, + Agent: cmd.Agent, + Model: cmd.Model, + Subtask: cmd.Subtask, }) } } @@ -565,6 +823,106 @@ func (s *Server) listCommands(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, commands) } +// getBuiltInCommands returns the built-in commands with their templates. +func getBuiltInCommands(workDir string) []CommandInfo { + return []CommandInfo{ + { + Name: "init", + Description: "create/update AGENTS.md", + Template: `Please analyze this codebase and create an AGENTS.md file containing: +1. Build/lint/test commands - especially for running a single test +2. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc. + +The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long. +If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them. + +If there's already an AGENTS.md, improve it if it's located in ` + workDir + ` + +$ARGUMENTS +`, + }, + { + Name: "review", + Description: "review changes [commit|branch|pr], defaults to uncommitted", + Template: `You are a code reviewer. Your job is to review code changes and provide actionable feedback. + +--- + +Input: $ARGUMENTS + +--- + +## Determining What to Review + +Based on the input provided, determine which type of review to perform: + +1. **No arguments (default)**: Review all uncommitted changes + - Run: ` + "`git diff`" + ` for unstaged changes + - Run: ` + "`git diff --cached`" + ` for staged changes + +2. **Commit hash** (40-char SHA or short hash): Review that specific commit + - Run: ` + "`git show $ARGUMENTS`" + ` + +3. **Branch name**: Compare current branch to the specified branch + - Run: ` + "`git diff $ARGUMENTS...HEAD`" + ` + +4. **PR URL or number** (contains "github.com" or "pull" or looks like a PR number): Review the pull request + - Run: ` + "`gh pr view $ARGUMENTS`" + ` to get PR context + - Run: ` + "`gh pr diff $ARGUMENTS`" + ` to get the diff + +Use best judgement when processing input. + +--- + +## What to Look For + +**Bugs** - Your primary focus. +- Logic errors, off-by-one mistakes, incorrect conditionals +- Edge cases: null/empty inputs, error conditions, race conditions +- Security issues: injection, auth bypass, data exposure +- Broken error handling that swallows failures + +**Structure** - Does the code fit the codebase? +- Does it follow existing patterns and conventions? +- Are there established abstractions it should use but doesn't? + +**Performance** - Only flag if obviously problematic. +- O(n²) on unbounded data, N+1 queries, blocking I/O on hot paths + +## Before You Flag Something + +Be certain. If you're going to call something a bug, you need to be confident it actually is one. + +- Only review the changes - do not review pre-existing code that wasn't modified +- Don't flag something as a bug if you're unsure - investigate first +- Don't flag style preferences as issues +- Don't invent hypothetical problems - if an edge case matters, explain the realistic scenario where it breaks +- If you need more context to be sure, use the tools below to get it + +## Tools + +Use these to inform your review: + +- **Explore agent** - Find how existing code handles similar problems. Check patterns, conventions, and prior art before claiming something doesn't fit. +- **Exa Code Context** - Verify correct usage of libraries/APIs before flagging something as wrong. +- **Exa Web Search** - Research best practices if you're unsure about a pattern. + +If you're uncertain about something and can't verify it with these tools, say "I'm not sure about X" rather than flagging it as a definite issue. + +## Tone and Approach + +1. If there is a bug, be direct and clear about why it is a bug. +2. You should clearly communicate severity of issues, do not claim issues are more severe than they actually are. +3. Critiques should clearly and explicitly communicate the scenarios, environments, or inputs that are necessary for the bug to arise. The comment should immediately indicate that the issue's severity depends on these factors. +4. Your tone should be matter-of-fact and not accusatory or overly positive. It should read as a helpful AI assistant suggestion without sounding too much like a human reviewer. +5. Write in a manner that allows reader to quickly understand issue without reading too closely. +6. AVOID flattery, do not give any comments that are not helpful to the reader. Avoid phrasing like "Great job ...", "Thanks for ...". +`, + Subtask: true, + }, + } +} + // executeCommand handles POST /command/{name} func (s *Server) executeCommand(w http.ResponseWriter, r *http.Request) { if s.commandExecutor == nil { diff --git a/go-opencode/internal/server/handlers_file.go b/go-opencode/internal/server/handlers_file.go index f3ae24f48c2..8b4ca7ae7b3 100644 --- a/go-opencode/internal/server/handlers_file.go +++ b/go-opencode/internal/server/handlers_file.go @@ -258,6 +258,21 @@ var symbolKindsFilter = map[lsp.SymbolKind]bool{ lsp.SymbolKindStruct: true, // 23 } +// getVCSInfo handles GET /vcs +// Returns the current git branch name. +func (s *Server) getVCSInfo(w http.ResponseWriter, r *http.Request) { + directory := getDirectory(r.Context()) + + // Get current branch + cmd := exec.Command("git", "branch", "--show-current") + cmd.Dir = directory + branch, _ := cmd.Output() + + writeJSON(w, http.StatusOK, map[string]any{ + "branch": strings.TrimSpace(string(branch)), + }) +} + // searchSymbols handles GET /find/symbol func (s *Server) searchSymbols(w http.ResponseWriter, r *http.Request) { query := r.URL.Query().Get("query") diff --git a/go-opencode/internal/server/handlers_message.go b/go-opencode/internal/server/handlers_message.go index 99298827b91..eacc293b4fe 100644 --- a/go-opencode/internal/server/handlers_message.go +++ b/go-opencode/internal/server/handlers_message.go @@ -86,6 +86,7 @@ func (s *Server) sendMessage(w http.ResponseWriter, r *http.Request) { } // Create user message + // SDK compatible: user messages include summary field (initially with empty diffs) userMsg := &types.Message{ ID: generateID(), SessionID: sessionID, @@ -93,6 +94,9 @@ func (s *Server) sendMessage(w http.ResponseWriter, r *http.Request) { Agent: req.Agent, Model: req.Model, Tools: req.Tools, + Summary: &types.UserMessageSummary{ + Diffs: []types.FileDiff{}, // SDK compatible: empty diffs array + }, Time: types.MessageTime{ Created: nowMillis(), }, diff --git a/go-opencode/internal/server/handlers_session.go b/go-opencode/internal/server/handlers_session.go index 4f575da1599..dab5ab6bac7 100644 --- a/go-opencode/internal/server/handlers_session.go +++ b/go-opencode/internal/server/handlers_session.go @@ -235,17 +235,34 @@ func (s *Server) unshareSession(w http.ResponseWriter, r *http.Request) { writeSuccess(w) } +// SummarizeSessionRequest represents the request body for summarizing a session. +type SummarizeSessionRequest struct { + ProviderID string `json:"providerID"` + ModelID string `json:"modelID"` +} + // summarizeSession handles POST /session/{sessionID}/summarize func (s *Server) summarizeSession(w http.ResponseWriter, r *http.Request) { sessionID := chi.URLParam(r, "sessionID") - summary, err := s.sessionService.Summarize(r.Context(), sessionID) + var req SummarizeSessionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "Invalid JSON body: providerID and modelID are required") + return + } + + if req.ProviderID == "" || req.ModelID == "" { + writeError(w, http.StatusBadRequest, ErrCodeInvalidRequest, "providerID and modelID are required") + return + } + + err := s.sessionService.Summarize(r.Context(), sessionID, req.ProviderID, req.ModelID) if err != nil { writeError(w, http.StatusInternalServerError, ErrCodeInternalError, err.Error()) return } - writeJSON(w, http.StatusOK, summary) + writeJSON(w, http.StatusOK, true) } // initSession handles POST /session/{sessionID}/init diff --git a/go-opencode/internal/server/handlers_test.go b/go-opencode/internal/server/handlers_test.go index 00249798f3b..9b8df3bdd94 100644 --- a/go-opencode/internal/server/handlers_test.go +++ b/go-opencode/internal/server/handlers_test.go @@ -238,6 +238,31 @@ func TestGetConfig(t *testing.T) { } } +func TestGetConfigIncludesKeybinds(t *testing.T) { + srv := setupTestServer(t) + + req := httptest.NewRequest("GET", "/config", nil) + w := httptest.NewRecorder() + + srv.getConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + var cfg types.Config + if err := json.NewDecoder(w.Body).Decode(&cfg); err != nil { + t.Fatalf("Failed to decode: %v", err) + } + + if cfg.Keybinds.SessionInterrupt != "escape" { + t.Errorf("Expected default session_interrupt to be escape, got %q", cfg.Keybinds.SessionInterrupt) + } + if cfg.Keybinds.Leader != "ctrl+x" { + t.Errorf("Expected default leader to be ctrl+x, got %q", cfg.Keybinds.Leader) + } +} + func TestReadFile_NotFound(t *testing.T) { srv := setupTestServer(t) diff --git a/go-opencode/internal/server/routes.go b/go-opencode/internal/server/routes.go index 1a16a86ff73..5629e40de70 100644 --- a/go-opencode/internal/server/routes.go +++ b/go-opencode/internal/server/routes.go @@ -86,6 +86,9 @@ func (s *Server) setupRoutes() { // Authentication r.Put("/auth/{providerID}", s.setAuth) + // VCS (Version Control System) + r.Get("/vcs", s.getVCSInfo) + // Advanced features r.Get("/lsp", s.getLSPStatus) r.Get("/agent", s.listAgents) diff --git a/go-opencode/internal/server/sse.go b/go-opencode/internal/server/sse.go index 4b3b94fba70..611820a3950 100644 --- a/go-opencode/internal/server/sse.go +++ b/go-opencode/internal/server/sse.go @@ -17,11 +17,22 @@ import ( "encoding/json" "fmt" "net/http" + "sync/atomic" "time" "github.com/opencode-ai/opencode/internal/event" ) +// SSE event counter for debugging +var sseEventCounter uint64 + +// SDKEvent represents an SDK-compatible event with proper JSON field ordering. +// TypeScript expects: {"type": "...", "properties": {...}} +type SDKEvent struct { + Type event.EventType `json:"type"` + Properties any `json:"properties"` +} + const ( // SSEHeartbeatInterval is the interval for SSE heartbeats. SSEHeartbeatInterval = 30 * time.Second @@ -31,28 +42,67 @@ const ( type sseWriter struct { w http.ResponseWriter flusher http.Flusher + rc *http.ResponseController } // newSSEWriter creates a new SSE writer. func newSSEWriter(w http.ResponseWriter) (*sseWriter, error) { + // Use ResponseController for more reliable flushing (Go 1.20+) + rc := http.NewResponseController(w) + + // Try to get flusher interface as well flusher, ok := w.(http.Flusher) if !ok { return nil, fmt.Errorf("streaming not supported") } - return &sseWriter{w: w, flusher: flusher}, nil + return &sseWriter{w: w, flusher: flusher, rc: rc}, nil } -// writeEvent writes an SSE event. +// writeEvent writes an SSE event with optional throttling. func (s *sseWriter) writeEvent(eventType string, data any) error { jsonData, err := json.Marshal(data) if err != nil { return err } - fmt.Fprintf(s.w, "event: %s\n", eventType) - fmt.Fprintf(s.w, "data: %s\n\n", jsonData) - s.flusher.Flush() + // Log SSE event for debugging - include event type from data if available + count := atomic.AddUint64(&sseEventCounter, 1) + dataType := "" + switch d := data.(type) { + case SDKEvent: + dataType = string(d.Type) + case map[string]any: + // Handle both string and event.EventType (which is a string alias) + switch t := d["type"].(type) { + case string: + dataType = t + case event.EventType: + dataType = string(t) + } + } + + t1 := time.Now() + fmt.Printf("[sse] #%d PRE-WRITE event=%s dataType=%s time=%s\n", + count, eventType, dataType, t1.Format("15:04:05.000")) + + // Write SSE format: event type, data, and blank line + _, err = fmt.Fprintf(s.w, "event: %s\ndata: %s\n\n", eventType, jsonData) + if err != nil { + return err + } + t2 := time.Now() + + // Flush immediately using ResponseController (more reliable than Flusher interface) + // This ensures data is sent even through middleware wrappers + if flushErr := s.rc.Flush(); flushErr != nil { + // Fallback to traditional flusher + s.flusher.Flush() + } + t3 := time.Now() + + fmt.Printf("[sse] #%d POST-FLUSH event=%s dataType=%s write=%v flush=%v total=%v\n", + count, eventType, dataType, t2.Sub(t1), t3.Sub(t2), t3.Sub(t1)) return nil } @@ -66,6 +116,7 @@ func (s *sseWriter) writeHeartbeat() { // allEvents handles SSE for all events (used by /event endpoint). // This is the main event endpoint that the TUI connects to. func (srv *Server) allEvents(w http.ResponseWriter, r *http.Request) { + fmt.Printf("[sse] allEvents: new connection from %s\n", r.RemoteAddr) // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -83,23 +134,27 @@ func (srv *Server) allEvents(w http.ResponseWriter, r *http.Request) { sse.flusher.Flush() // Send server.connected event first (SDK compatible) - connectedEvent := map[string]any{ - "type": "server.connected", - "properties": map[string]any{}, + connectedEvent := SDKEvent{ + Type: "server.connected", + Properties: map[string]any{}, } if err := sse.writeEvent("message", connectedEvent); err != nil { return } - // Channel for events - events := make(chan event.Event, 100) + // Channel for events - use small buffer for low-latency streaming + events := make(chan event.Event, 10) // Subscribe to all events + var recvCounter uint64 unsub := event.SubscribeAll(func(e event.Event) { + count := atomic.AddUint64(&recvCounter, 1) + fmt.Printf("[sse] #%d recv type=%s time=%s\n", + count, e.Type, time.Now().Format("15:04:05.000")) select { case events <- e: default: - // Drop event if channel is full + fmt.Printf("[sse] #%d DROPPED (channel full) type=%s\n", count, e.Type) } }) defer unsub() @@ -108,18 +163,23 @@ func (srv *Server) allEvents(w http.ResponseWriter, r *http.Request) { ticker := time.NewTicker(SSEHeartbeatInterval) defer ticker.Stop() + fmt.Printf("[sse] allEvents: entering select loop, channel len=%d\n", len(events)) + // Wait for client disconnect or context cancellation for { select { case <-r.Context().Done(): + fmt.Printf("[sse] allEvents: context done\n") return case e := <-events: - // SDK compatible format: use "properties" instead of "data" - data := map[string]any{ - "type": e.Type, - "properties": e.Data, + fmt.Printf("[sse] allEvents: got event from channel type=%s\n", e.Type) + // SDK compatible format: use struct for proper field ordering + data := SDKEvent{ + Type: e.Type, + Properties: e.Data, } if err := sse.writeEvent("message", data); err != nil { + fmt.Printf("[sse] allEvents: write error: %v\n", err) return } case <-ticker.C: @@ -130,6 +190,7 @@ func (srv *Server) allEvents(w http.ResponseWriter, r *http.Request) { // globalEvents handles SSE for all events. func (srv *Server) globalEvents(w http.ResponseWriter, r *http.Request) { + fmt.Printf("[sse] globalEvents: new connection from %s\n", r.RemoteAddr) // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -147,15 +208,19 @@ func (srv *Server) globalEvents(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) sse.flusher.Flush() - // Channel for events - events := make(chan event.Event, 100) + // Channel for events - use small buffer for low-latency streaming + events := make(chan event.Event, 10) // Subscribe to all events + var recvCounter uint64 unsub := event.SubscribeAll(func(e event.Event) { + count := atomic.AddUint64(&recvCounter, 1) + fmt.Printf("[sse-global] #%d recv type=%s time=%s\n", + count, e.Type, time.Now().Format("15:04:05.000")) select { case events <- e: default: - // Drop event if channel is full + fmt.Printf("[sse-global] #%d DROPPED (channel full) type=%s\n", count, e.Type) } }) defer unsub() @@ -170,10 +235,10 @@ func (srv *Server) globalEvents(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): return case e := <-events: - // SDK compatible format: use "properties" instead of "data" - data := map[string]any{ - "type": e.Type, - "properties": e.Data, + // SDK compatible format: use struct for proper field ordering + data := SDKEvent{ + Type: e.Type, + Properties: e.Data, } if err := sse.writeEvent("message", data); err != nil { return @@ -209,8 +274,8 @@ func (srv *Server) sessionEvents(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) sse.flusher.Flush() - // Channel for events - events := make(chan event.Event, 100) + // Channel for events - use small buffer for low-latency streaming + events := make(chan event.Event, 10) // Filter for session-specific events unsub := event.SubscribeAll(func(e event.Event) { @@ -233,10 +298,10 @@ func (srv *Server) sessionEvents(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): return case e := <-events: - // SDK compatible format: use "properties" instead of "data" - data := map[string]any{ - "type": e.Type, - "properties": e.Data, + // SDK compatible format: use struct for proper field ordering + data := SDKEvent{ + Type: e.Type, + Properties: e.Data, } if err := sse.writeEvent("message", data); err != nil { return @@ -263,6 +328,8 @@ func (srv *Server) eventBelongsToSession(e event.Event, sessionID string) bool { return data.Info != nil && data.Info.ID == sessionID case event.SessionDeletedData: return data.Info != nil && data.Info.ID == sessionID + case event.SessionDiffData: + return data.SessionID == sessionID case event.PermissionUpdatedData: return data.SessionID == sessionID case event.PermissionRepliedData: diff --git a/go-opencode/internal/session/compact.go b/go-opencode/internal/session/compact.go index 599719201d9..7a90e611008 100644 --- a/go-opencode/internal/session/compact.go +++ b/go-opencode/internal/session/compact.go @@ -9,6 +9,7 @@ import ( "github.com/cloudwego/eino/schema" + "github.com/opencode-ai/opencode/internal/event" "github.com/opencode-ai/opencode/internal/provider" "github.com/opencode-ai/opencode/pkg/types" ) @@ -120,7 +121,7 @@ func (p *Processor) compactMessages( // Create compaction marker in session // This would be used to inject the summary into future prompts session.Summary.Diffs = append(session.Summary.Diffs, types.FileDiff{ - Path: "__compaction__", + File: "__compaction__", Before: "", After: summary.String(), }) @@ -177,20 +178,231 @@ func buildSummaryPrompt(ctx context.Context, p *Processor, messages []*types.Mes return prompt.String() } -// CompactionPart represents a summary of compacted messages. -type CompactionPart struct { - ID string `json:"id"` - Type string `json:"type"` // always "compaction" - Summary string `json:"summary"` - Count int `json:"count"` // Number of messages summarized -} - -func (p *CompactionPart) PartType() string { return "compaction" } -func (p *CompactionPart) PartID() string { return p.ID } - // estimateTokens provides a rough estimate of token count. func estimateTokens(text string) int { // Rough estimate: ~4 characters per token return len(text) / 4 } +// compactionSystemPrompt is the system prompt for generating summaries. +const compactionSystemPrompt = `You are a conversation summarizer. Create a concise summary of the conversation that preserves key context for continuing the discussion. + +Focus on: +1. What was accomplished +2. Current work in progress +3. Files involved +4. Next steps +5. Any key user requests or constraints + +Be concise but detailed enough that work can continue seamlessly.` + +// processCompaction handles a compaction request by summarizing the conversation. +func (p *Processor) processCompaction( + ctx context.Context, + sessionID string, + messages []*types.Message, + compactionPart *types.CompactionPart, + callback ProcessCallback, +) error { + fmt.Printf("[compact] Processing compaction for session %s\n", sessionID) + + // Find session + session, err := p.findSession(ctx, sessionID) + if err != nil { + return err + } + + // Get the last user message (which contains the compaction part) + lastMsg := messages[len(messages)-1] + + // Get provider and model from the user message + providerID := p.defaultProviderID + modelID := p.defaultModelID + if lastMsg.Model != nil { + providerID = lastMsg.Model.ProviderID + modelID = lastMsg.Model.ModelID + } + + prov, err := p.providerRegistry.Get(providerID) + if err != nil { + return fmt.Errorf("provider not found: %w", err) + } + + model, err := p.providerRegistry.GetModel(providerID, modelID) + if err != nil { + return fmt.Errorf("model not found: %w", err) + } + + // Set compacting flag on session + now := time.Now().UnixMilli() + session.Time.Compacting = &now + p.storage.Put(ctx, []string{"session", session.ProjectID, session.ID}, session) + + defer func() { + session.Time.Compacting = nil + p.storage.Put(ctx, []string{"session", session.ProjectID, session.ID}, session) + }() + + // Build summary prompt from all messages except the compaction request itself + summaryPrompt := buildSummaryPrompt(ctx, p, messages[:len(messages)-1]) + summaryPrompt += "\n\nSummarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly." + + // Create assistant message with summary flag + assistantMsg := &types.Message{ + ID: generatePartID(), + SessionID: sessionID, + Role: "assistant", + ParentID: lastMsg.ID, + ProviderID: providerID, + ModelID: modelID, + Mode: lastMsg.Agent, + IsSummary: true, // Mark as summary message + Path: &types.MessagePath{ + Cwd: session.Directory, + Root: session.Directory, + }, + Time: types.MessageTime{ + Created: now, + }, + Tokens: &types.TokenUsage{Input: 0, Output: 0}, + } + + // Save initial message + if err := p.storage.Put(ctx, []string{"message", sessionID, assistantMsg.ID}, assistantMsg); err != nil { + return fmt.Errorf("failed to save message: %w", err) + } + + // Notify callback + callback(assistantMsg, nil) + + // Publish message created event + event.Publish(event.Event{ + Type: event.MessageCreated, + Data: event.MessageCreatedData{Info: assistantMsg}, + }) + + // Create text part for streaming the summary + textPart := &types.TextPart{ + ID: generatePartID(), + SessionID: sessionID, + MessageID: assistantMsg.ID, + Type: "text", + Text: "", + } + + // Save initial part + if err := p.storage.Put(ctx, []string{"part", assistantMsg.ID, textPart.ID}, textPart); err != nil { + return fmt.Errorf("failed to save part: %w", err) + } + + // Publish part created event + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{Part: textPart}, + }) + + // Generate summary using LLM + systemMsg := &schema.Message{ + Role: schema.System, + Content: compactionSystemPrompt, + } + + userMsg := &schema.Message{ + Role: schema.User, + Content: summaryPrompt, + } + + stream, err := prov.CreateCompletion(ctx, &provider.CompletionRequest{ + Model: model.ID, + Messages: []*schema.Message{systemMsg, userMsg}, + MaxTokens: DefaultCompactionConfig.SummaryMaxTokens, + }) + if err != nil { + return fmt.Errorf("failed to create completion: %w", err) + } + defer stream.Close() + + // Stream the response + var fullText strings.Builder + for { + msg, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("stream error: %w", err) + } + + fullText.WriteString(msg.Content) + textPart.Text = fullText.String() + + // Save updated part + p.storage.Put(ctx, []string{"part", assistantMsg.ID, textPart.ID}, textPart) + + // Publish streaming update with delta + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{ + Part: textPart, + Delta: msg.Content, + }, + }) + } + + // Update message with final token counts + // (In a full implementation, we'd get actual token counts from the provider) + assistantMsg.Tokens = &types.TokenUsage{ + Input: estimateTokens(summaryPrompt), + Output: estimateTokens(fullText.String()), + } + p.storage.Put(ctx, []string{"message", sessionID, assistantMsg.ID}, assistantMsg) + + // Publish message updated event + event.Publish(event.Event{ + Type: event.MessageUpdated, + Data: event.MessageUpdatedData{Info: assistantMsg}, + }) + + // Publish session.compacted event + event.Publish(event.Event{ + Type: event.SessionCompacted, + Data: event.SessionCompactedData{SessionID: sessionID}, + }) + + fmt.Printf("[compact] Compaction complete for session %s\n", sessionID) + + // If auto-compaction, add a "Continue if you have next steps" message + if compactionPart.Auto { + continueMsg := &types.Message{ + ID: generatePartID(), + SessionID: sessionID, + Role: "user", + Agent: lastMsg.Agent, + Model: lastMsg.Model, + Time: types.MessageTime{ + Created: time.Now().UnixMilli(), + }, + } + p.storage.Put(ctx, []string{"message", sessionID, continueMsg.ID}, continueMsg) + + continuePart := &types.TextPart{ + ID: generatePartID(), + SessionID: sessionID, + MessageID: continueMsg.ID, + Type: "text", + Text: "Continue if you have next steps", + } + p.storage.Put(ctx, []string{"part", continueMsg.ID, continuePart.ID}, continuePart) + + event.Publish(event.Event{ + Type: event.MessageCreated, + Data: event.MessageCreatedData{Info: continueMsg}, + }) + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{Part: continuePart}, + }) + } + + return nil +} diff --git a/go-opencode/internal/session/loop.go b/go-opencode/internal/session/loop.go index a4483240691..ffac6ab0eca 100644 --- a/go-opencode/internal/session/loop.go +++ b/go-opencode/internal/session/loop.go @@ -58,13 +58,28 @@ func (p *Processor) runLoop( var session types.Session if err := p.storage.Get(ctx, []string{"session", sessionID}, &session); err != nil { // Try to find session in any project - session, err := p.findSession(ctx, sessionID) - if err != nil { - return fmt.Errorf("session not found: %w", err) + foundSession, findErr := p.findSession(ctx, sessionID) + if findErr != nil { + return fmt.Errorf("session not found: %w", findErr) } - _ = session + session = *foundSession } + // Emit initial session.updated event + event.Publish(event.Event{ + Type: event.SessionUpdated, + Data: event.SessionUpdatedData{Info: &session}, + }) + + // Emit initial session.diff event (empty diffs at start) + event.Publish(event.Event{ + Type: event.SessionDiff, + Data: event.SessionDiffData{ + SessionID: sessionID, + Diff: session.Summary.Diffs, + }, + }) + // Load messages messages, err := p.loadMessages(ctx, sessionID) if err != nil { @@ -82,15 +97,24 @@ func (p *Processor) runLoop( } // Load and log user message parts userParts, _ := p.loadParts(ctx, lastMsg.ID) + var compactionPart *types.CompactionPart for i, part := range userParts { switch pt := part.(type) { case *types.TextPart: fmt.Printf("[loop] User message part %d: type=text content=%q\n", i, truncateStr(pt.Text, 50)) + case *types.CompactionPart: + fmt.Printf("[loop] User message part %d: type=compaction auto=%v\n", i, pt.Auto) + compactionPart = pt default: fmt.Printf("[loop] User message part %d: type=%T\n", i, pt) } } + // If this is a compaction request, process it + if compactionPart != nil { + return p.processCompaction(ctx, sessionID, messages, compactionPart, callback) + } + // Get provider and model providerID := p.defaultProviderID modelID := p.defaultModelID @@ -125,9 +149,14 @@ func (p *Processor) runLoop( ID: generatePartID(), SessionID: sessionID, Role: "assistant", + ParentID: lastMsg.ID, // Link to the user message that prompted this ProviderID: providerID, ModelID: modelID, Mode: agent.Name, // Agent name (e.g., "Coder", "Build") - required by TUI + Path: &types.MessagePath{ + Cwd: session.Directory, // Current working directory from session + Root: session.Directory, // Root directory (same as cwd for now) + }, Time: types.MessageTime{ Created: now, }, @@ -315,9 +344,10 @@ func (p *Processor) runLoop( p.saveMessage(ctx, sessionID, assistantMsg) return nil - case "tool_use", "tool_calls": + case "tool_use", "tool_calls", "tool-calls": // Execute tools and continue loop - fmt.Printf("[loop] Got tool_use/tool_calls, calling executeToolCalls with %d parts\n", len(state.parts)) + // Note: "tool-calls" is SDK compatible (TypeScript), "tool_use" is from some providers + fmt.Printf("[loop] Got tool_use/tool_calls/tool-calls, calling executeToolCalls with %d parts\n", len(state.parts)) if err := p.executeToolCalls(ctx, state, agent, callback); err != nil { fmt.Printf("[loop] executeToolCalls returned error: %v\n", err) // Tool execution errors don't stop the loop diff --git a/go-opencode/internal/session/processor.go b/go-opencode/internal/session/processor.go index d66108da89f..199f8c7af38 100644 --- a/go-opencode/internal/session/processor.go +++ b/go-opencode/internal/session/processor.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/opencode-ai/opencode/internal/event" "github.com/opencode-ai/opencode/internal/permission" "github.com/opencode-ai/opencode/internal/provider" "github.com/opencode-ai/opencode/internal/storage" @@ -104,6 +105,15 @@ func (p *Processor) Process(ctx context.Context, sessionID string, agent *Agent, p.sessions[sessionID] = state p.mu.Unlock() + // Emit session.status busy event + event.Publish(event.Event{ + Type: event.SessionStatus, + Data: event.SessionStatusData{ + SessionID: sessionID, + Status: event.SessionStatusInfo{Type: "busy"}, + }, + }) + // Ensure cleanup defer func() { p.mu.Lock() @@ -114,6 +124,21 @@ func (p *Processor) Process(ctx context.Context, sessionID string, agent *Agent, waiter <- nil } p.mu.Unlock() + + // Emit session.status idle event (SDK compatible: TUI uses this to stop progress bar) + event.Publish(event.Event{ + Type: event.SessionStatus, + Data: event.SessionStatusData{ + SessionID: sessionID, + Status: event.SessionStatusInfo{Type: "idle"}, + }, + }) + + // Emit session.idle event when processing completes + event.Publish(event.Event{ + Type: event.SessionIdle, + Data: event.SessionIdleData{SessionID: sessionID}, + }) }() // Run the agentic loop diff --git a/go-opencode/internal/session/processor_test.go b/go-opencode/internal/session/processor_test.go index 99a7fbeaee8..523b6b8b805 100644 --- a/go-opencode/internal/session/processor_test.go +++ b/go-opencode/internal/session/processor_test.go @@ -293,15 +293,18 @@ func TestToolState(t *testing.T) { } func TestCompactionPart(t *testing.T) { - part := &CompactionPart{ - ID: "test-id", - Type: "compaction", - Summary: "This is a summary", - Count: 5, + part := &types.CompactionPart{ + ID: "test-id", + SessionID: "session-1", + MessageID: "msg-1", + Type: "compaction", + Auto: false, } assert.Equal(t, "compaction", part.PartType()) assert.Equal(t, "test-id", part.PartID()) + assert.Equal(t, "session-1", part.PartSessionID()) + assert.Equal(t, "msg-1", part.PartMessageID()) } func TestSessionState(t *testing.T) { diff --git a/go-opencode/internal/session/service.go b/go-opencode/internal/session/service.go index 4885a52cb86..697457e42e1 100644 --- a/go-opencode/internal/session/service.go +++ b/go-opencode/internal/session/service.go @@ -12,6 +12,7 @@ import ( "github.com/oklog/ulid/v2" + "github.com/opencode-ai/opencode/internal/event" "github.com/opencode-ai/opencode/internal/permission" "github.com/opencode-ai/opencode/internal/provider" "github.com/opencode-ai/opencode/internal/storage" @@ -273,6 +274,12 @@ func (s *Service) Fork(ctx context.Context, sessionID, messageID string) (*types // Abort aborts an active session. func (s *Service) Abort(ctx context.Context, sessionID string) error { + // Use the processor's abort mechanism which cancels the context + if s.processor != nil { + return s.processor.Abort(sessionID) + } + + // Fallback to channel-based abort (legacy) s.mu.Lock() defer s.mu.Unlock() @@ -317,14 +324,85 @@ func (s *Service) Unshare(ctx context.Context, sessionID string) error { return s.storage.Put(ctx, []string{"session", session.ProjectID, session.ID}, session) } -// Summarize generates a summary of the session. -func (s *Service) Summarize(ctx context.Context, sessionID string) (*types.SessionSummary, error) { +// Summarize initiates a compaction/summarization of the session. +// This creates a user message with a compaction part and triggers the processing loop. +func (s *Service) Summarize(ctx context.Context, sessionID, providerID, modelID string) error { session, err := s.Get(ctx, sessionID) if err != nil { - return nil, err + return err + } + + // Get the current agent from the last user message + messages, err := s.GetMessages(ctx, sessionID) + if err != nil { + return err + } + + currentAgent := "default" + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + if messages[i].Agent != "" { + currentAgent = messages[i].Agent + } + break + } + } + + // Create a user message with a compaction part + now := time.Now().UnixMilli() + userMsg := &types.Message{ + ID: ulid.Make().String(), + SessionID: sessionID, + Role: "user", + Agent: currentAgent, + Model: &types.ModelRef{ + ProviderID: providerID, + ModelID: modelID, + }, + Time: types.MessageTime{ + Created: now, + }, + } + + // Store the user message + if err := s.storage.Put(ctx, []string{"message", sessionID, userMsg.ID}, userMsg); err != nil { + return err + } + + // Publish message created event + event.Publish(event.Event{ + Type: event.MessageCreated, + Data: event.MessageCreatedData{Info: userMsg}, + }) + + // Create the compaction part + compactionPart := &types.CompactionPart{ + ID: ulid.Make().String(), + SessionID: sessionID, + MessageID: userMsg.ID, + Type: "compaction", + Auto: false, } - return &session.Summary, nil + // Store the compaction part + if err := s.storage.Put(ctx, []string{"part", userMsg.ID, compactionPart.ID}, compactionPart); err != nil { + return err + } + + // Publish part updated event + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{Part: compactionPart}, + }) + + // Trigger the processing loop + if s.processor != nil { + go func() { + s.processor.Process(context.Background(), session.ID, nil, nil) + }() + } + + return nil } // GetDiffs returns diffs for a session. diff --git a/go-opencode/internal/session/stream.go b/go-opencode/internal/session/stream.go index c6f4846316f..1d4f7cb53c6 100644 --- a/go-opencode/internal/session/stream.go +++ b/go-opencode/internal/session/stream.go @@ -32,15 +32,25 @@ func (p *Processor) processStream( currentToolParts = make(map[string]*types.ToolPart) accumulatedToolInputs = make(map[string]string) - // Emit step start - stepStartPart := &types.TextPart{ - ID: generatePartID(), - Type: "step-start", + // Emit step-start part at the beginning of inference + stepStartPart := &types.StepStartPart{ + ID: generatePartID(), + SessionID: state.message.SessionID, + MessageID: state.message.ID, + Type: "step-start", } - _ = stepStartPart // We'll add step tracking later + state.parts = append(state.parts, stepStartPart) + p.savePart(ctx, state.message.ID, stepStartPart) + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{Part: stepStartPart}, + }) + callback(state.message, state.parts) fmt.Printf("[stream] Starting to receive chunks\n") chunkCount := 0 + var lastChunkTime time.Time + var lastEventTime time.Time // For throttling event publishing for { select { @@ -60,13 +70,19 @@ func (p *Processor) processStream( return "error", err } chunkCount++ - fmt.Printf("[stream] Chunk %d: content=%q, toolCalls=%d, responseMeta=%v\n", - chunkCount, truncate(msg.Content, 50), len(msg.ToolCalls), msg.ResponseMeta != nil) + now := time.Now() + var delta time.Duration + if !lastChunkTime.IsZero() { + delta = now.Sub(lastChunkTime) + } + lastChunkTime = now + fmt.Printf("[stream] Chunk %d (+%v): content=%q, toolCalls=%d, responseMeta=%v\n", + chunkCount, delta, truncate(msg.Content, 50), len(msg.ToolCalls), msg.ResponseMeta != nil) // Process the message chunk finishReason = p.processMessageChunk(ctx, msg, state, callback, ¤tTextPart, ¤tReasoningPart, currentToolParts, - &accumulatedContent, accumulatedToolInputs) + &accumulatedContent, accumulatedToolInputs, &lastEventTime) if finishReason != "" { break @@ -105,12 +121,36 @@ func (p *Processor) processStream( // Determine finish reason from accumulated state if finishReason == "" { if len(currentToolParts) > 0 { - finishReason = "tool_use" + finishReason = "tool-calls" // SDK compatible: TypeScript uses "tool-calls" } else { finishReason = "stop" } } + // Normalize finish reason to SDK-compatible format + // TypeScript uses "tool-calls" but some providers return "tool_use" + if finishReason == "tool_use" { + finishReason = "tool-calls" + } + + // Emit step-finish part at the end of inference with cost and token info + stepFinishPart := &types.StepFinishPart{ + ID: generatePartID(), + SessionID: state.message.SessionID, + MessageID: state.message.ID, + Type: "step-finish", + Reason: finishReason, + Cost: state.message.Cost, + Tokens: state.message.Tokens, + } + state.parts = append(state.parts, stepFinishPart) + p.savePart(ctx, state.message.ID, stepFinishPart) + event.Publish(event.Event{ + Type: event.MessagePartUpdated, + Data: event.MessagePartUpdatedData{Part: stepFinishPart}, + }) + callback(state.message, state.parts) + fmt.Printf("[stream] Finished with reason=%s, parts=%d, tokens=%v\n", finishReason, len(state.parts), state.message.Tokens) @@ -125,6 +165,27 @@ func truncate(s string, n int) string { return s[:n] + "..." } +// MinEventInterval is the minimum time between streaming events. +// This ensures the TUI has time to process each event before the next arrives. +// Set to slightly above TUI's 16ms batching window to prevent batching. +const MinEventInterval = 20 * time.Millisecond + +// throttledPublish publishes an event with optional throttling to prevent TUI batching. +func throttledPublish(e event.Event, lastEventTime *time.Time) { + if lastEventTime != nil && !lastEventTime.IsZero() { + elapsed := time.Since(*lastEventTime) + if elapsed < MinEventInterval { + sleepTime := MinEventInterval - elapsed + fmt.Printf("[stream] THROTTLE sleep=%v (elapsed=%v)\n", sleepTime, elapsed) + time.Sleep(sleepTime) + } + } + event.Publish(e) + if lastEventTime != nil { + *lastEventTime = time.Now() + } +} + // processMessageChunk handles a single message chunk from the stream. func (p *Processor) processMessageChunk( ctx context.Context, @@ -136,6 +197,7 @@ func (p *Processor) processMessageChunk( currentToolParts map[string]*types.ToolPart, accumulatedContent *string, accumulatedToolInputs map[string]string, + lastEventTime *time.Time, ) string { var finishReason string @@ -158,13 +220,14 @@ func (p *Processor) processMessageChunk( // Publish delta event for FIRST chunk (SDK compatible) // This ensures the TUI receives and displays the first text chunk - event.Publish(event.Event{ + // Note: Uses throttledPublish to prevent TUI batching + throttledPublish(event.Event{ Type: event.MessagePartUpdated, Data: event.MessagePartUpdatedData{ Part: *currentTextPart, Delta: msg.Content, // First chunk IS the delta }, - }) + }, lastEventTime) callback(state.message, state.parts) } else { @@ -183,13 +246,14 @@ func (p *Processor) processMessageChunk( } // Publish delta event (SDK compatible: uses MessagePartUpdated) - event.Publish(event.Event{ + // Note: Uses throttledPublish to prevent TUI batching + throttledPublish(event.Event{ Type: event.MessagePartUpdated, Data: event.MessagePartUpdatedData{ Part: *currentTextPart, Delta: delta, }, - }) + }, lastEventTime) callback(state.message, state.parts) } @@ -282,6 +346,7 @@ func (p *Processor) processMessageChunk( } // Publish tool part update (SDK compatible: uses MessagePartUpdated) + // Note: Must use async Publish so SSE select loop can process events event.Publish(event.Event{ Type: event.MessagePartUpdated, Data: event.MessagePartUpdatedData{ diff --git a/go-opencode/internal/session/tools.go b/go-opencode/internal/session/tools.go index 7a9b2dff7a2..4e37009fe8f 100644 --- a/go-opencode/internal/session/tools.go +++ b/go-opencode/internal/session/tools.go @@ -3,13 +3,17 @@ package session import ( "context" "encoding/json" + "errors" "fmt" + "path/filepath" + "strings" "time" "github.com/opencode-ai/opencode/internal/event" "github.com/opencode-ai/opencode/internal/permission" "github.com/opencode-ai/opencode/internal/tool" "github.com/opencode-ai/opencode/pkg/types" + "github.com/sergi/go-diff/diffmatchpatch" ) // executeToolCalls executes all pending tool calls in the state. @@ -115,8 +119,13 @@ func (p *Processor) executeSingleTool( MessageID: state.message.ID, CallID: toolPart.CallID, Agent: agent.Name, - WorkDir: "", - AbortCh: abortCh, + WorkDir: func() string { + if state.message.Path != nil { + return state.message.Path.Cwd + } + return "" + }(), + AbortCh: abortCh, Extra: map[string]any{ "model": state.message.ModelID, }, @@ -181,6 +190,11 @@ func (p *Processor) executeSingleTool( } } + // Record diff for edit-like tools when metadata contains before/after + if err := p.recordDiff(state, toolPart); err != nil { + fmt.Printf("[tools] failed to record diff: %v\n", err) + } + // Save updated part p.savePart(ctx, state.message.ID, toolPart) @@ -220,7 +234,7 @@ func (p *Processor) failTool( }) callback(state.message, state.parts) - return fmt.Errorf(errMsg) + return errors.New(errMsg) } // checkToolPermission checks if the tool execution is permitted. @@ -284,6 +298,345 @@ func (p *Processor) checkToolPermission( return p.permissionChecker.Check(ctx, req, action) } +// recordDiff captures file diffs from tool metadata and updates session summary/state. +func (p *Processor) recordDiff(state *sessionState, toolPart *types.ToolPart) error { + if toolPart.State.Metadata == nil { + toolPart.State.Metadata = make(map[string]any) + } + + pathVal, ok := toolPart.State.Metadata["file"].(string) + if !ok || pathVal == "" { + return nil + } + + before, okBefore := toolPart.State.Metadata["before"].(string) + after, okAfter := toolPart.State.Metadata["after"].(string) + if !okBefore || !okAfter { + return nil + } + + root := "" + if state.message.Path != nil { + root = state.message.Path.Root + } + relPath := pathVal + if root != "" { + if rp, err := filepath.Rel(root, pathVal); err == nil { + relPath = rp + } + } + + diffText, additions, deletions, err := computeDiff(before, after, relPath) + if err != nil { + return err + } + + fileDiff := types.FileDiff{ + File: relPath, + Additions: additions, + Deletions: deletions, + Before: before, + After: after, + } + + // Load session to update summary + session, err := p.loadSession(state.message.SessionID) + if err != nil { + return err + } + + // Replace existing diff for same path, then append + var filtered []types.FileDiff + for _, d := range session.Summary.Diffs { + if d.File != relPath { + filtered = append(filtered, d) + } + } + filtered = append(filtered, fileDiff) + session.Summary.Diffs = filtered + + // Recompute summary totals + adds, dels, files := 0, 0, len(session.Summary.Diffs) + for _, d := range session.Summary.Diffs { + adds += d.Additions + dels += d.Deletions + } + session.Summary.Additions = adds + session.Summary.Deletions = dels + session.Summary.Files = files + session.Time.Updated = time.Now().UnixMilli() + + if err := p.saveSession(session); err != nil { + return err + } + + // Publish updated session diff + event.Publish(event.Event{ + Type: event.SessionDiff, + Data: event.SessionDiffData{SessionID: session.ID, Diff: session.Summary.Diffs}, + }) + + // Attach diff text to metadata for consumers (non-breaking) + toolPart.State.Metadata["diff"] = diffText + if toolPart.Metadata == nil { + toolPart.Metadata = map[string]any{} + } + toolPart.Metadata["diff"] = diffText + return nil +} + +func computeDiff(before, after, path string) (string, int, int, error) { + dmp := diffmatchpatch.New() + + // Compute line-based diff for accurate line counting + a, b, lineArray := dmp.DiffLinesToChars(before, after) + diffs := dmp.DiffMain(a, b, false) + diffs = dmp.DiffCharsToLines(diffs, lineArray) + + // Count additions and deletions by lines + additions, deletions := 0, 0 + for _, d := range diffs { + switch d.Type { + case diffmatchpatch.DiffInsert: + lines := countLines(d.Text) + additions += lines + case diffmatchpatch.DiffDelete: + lines := countLines(d.Text) + deletions += lines + } + } + + // Generate proper unified diff text for display + diffText := generateUnifiedDiff(diffs, path) + + return diffText, additions, deletions, nil +} + +// countLines counts the number of lines in text +func countLines(text string) int { + if text == "" { + return 0 + } + lines := strings.Count(text, "\n") + // If text doesn't end with newline, count it as a line + if !strings.HasSuffix(text, "\n") { + lines++ + } + return lines +} + +// generateUnifiedDiff creates a proper unified diff format from diffs with context lines +func generateUnifiedDiff(diffs []diffmatchpatch.Diff, path string) string { + if len(diffs) == 0 { + return "" + } + + // Check if there are any actual changes + hasChanges := false + for _, d := range diffs { + if d.Type != diffmatchpatch.DiffEqual { + hasChanges = true + break + } + } + if !hasChanges { + return "" + } + + // Convert diffs to lines with their types + type diffLine struct { + text string + diffType diffmatchpatch.Operation + } + var allLines []diffLine + + for _, d := range diffs { + text := d.Text + lines := strings.Split(text, "\n") + // Handle trailing newline - if text ends with \n, the last split element is empty + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + for _, line := range lines { + allLines = append(allLines, diffLine{text: line, diffType: d.Type}) + } + } + + // Find ranges of changes with context (3 lines before and after) + const contextLines = 3 + type hunk struct { + startOld, countOld int + startNew, countNew int + lines []diffLine + } + + var hunks []hunk + var currentHunk *hunk + oldLineNum := 1 + newLineNum := 1 + + for i, line := range allLines { + isChange := line.diffType != diffmatchpatch.DiffEqual + + if isChange { + // Start a new hunk or extend current one + if currentHunk == nil { + // Calculate start positions including context + contextStart := i - contextLines + if contextStart < 0 { + contextStart = 0 + } + + // Calculate old/new line numbers at context start + startOld := 1 + startNew := 1 + for j := 0; j < contextStart; j++ { + switch allLines[j].diffType { + case diffmatchpatch.DiffEqual: + startOld++ + startNew++ + case diffmatchpatch.DiffDelete: + startOld++ + case diffmatchpatch.DiffInsert: + startNew++ + } + } + + currentHunk = &hunk{ + startOld: startOld, + startNew: startNew, + } + + // Add context lines before the change + for j := contextStart; j < i; j++ { + currentHunk.lines = append(currentHunk.lines, allLines[j]) + } + } + currentHunk.lines = append(currentHunk.lines, line) + } else if currentHunk != nil { + // Check if we should end the hunk or continue with context + // Look ahead to see if there's another change within context range + nextChangeIdx := -1 + for j := i + 1; j < len(allLines) && j <= i+contextLines*2; j++ { + if allLines[j].diffType != diffmatchpatch.DiffEqual { + nextChangeIdx = j + break + } + } + + if nextChangeIdx != -1 && nextChangeIdx <= i+contextLines*2 { + // Another change is close, include this line and continue + currentHunk.lines = append(currentHunk.lines, line) + } else { + // Add remaining context lines and close hunk + for j := i; j < len(allLines) && j < i+contextLines; j++ { + if allLines[j].diffType == diffmatchpatch.DiffEqual { + currentHunk.lines = append(currentHunk.lines, allLines[j]) + } else { + break + } + } + + // Calculate counts + for _, l := range currentHunk.lines { + switch l.diffType { + case diffmatchpatch.DiffEqual: + currentHunk.countOld++ + currentHunk.countNew++ + case diffmatchpatch.DiffDelete: + currentHunk.countOld++ + case diffmatchpatch.DiffInsert: + currentHunk.countNew++ + } + } + + hunks = append(hunks, *currentHunk) + currentHunk = nil + } + } + + // Track line numbers + switch line.diffType { + case diffmatchpatch.DiffEqual: + oldLineNum++ + newLineNum++ + case diffmatchpatch.DiffDelete: + oldLineNum++ + case diffmatchpatch.DiffInsert: + newLineNum++ + } + } + + // Close any remaining hunk + if currentHunk != nil { + for _, l := range currentHunk.lines { + switch l.diffType { + case diffmatchpatch.DiffEqual: + currentHunk.countOld++ + currentHunk.countNew++ + case diffmatchpatch.DiffDelete: + currentHunk.countOld++ + case diffmatchpatch.DiffInsert: + currentHunk.countNew++ + } + } + hunks = append(hunks, *currentHunk) + } + + // Build output + var buf strings.Builder + + // Write file headers + buf.WriteString("Index: ") + buf.WriteString(path) + buf.WriteString("\n") + buf.WriteString("===================================================================\n") + buf.WriteString("--- ") + buf.WriteString(path) + buf.WriteString("\n") + buf.WriteString("+++ ") + buf.WriteString(path) + buf.WriteString("\n") + + // Write each hunk + for _, h := range hunks { + buf.WriteString(fmt.Sprintf("@@ -%d,%d +%d,%d @@\n", h.startOld, h.countOld, h.startNew, h.countNew)) + for _, line := range h.lines { + switch line.diffType { + case diffmatchpatch.DiffEqual: + buf.WriteString(" ") + case diffmatchpatch.DiffDelete: + buf.WriteString("-") + case diffmatchpatch.DiffInsert: + buf.WriteString("+") + } + buf.WriteString(line.text) + buf.WriteString("\n") + } + } + + return buf.String() +} + +func (p *Processor) loadSession(sessionID string) (*types.Session, error) { + projects, err := p.storage.List(context.Background(), []string{"session"}) + if err != nil { + return nil, err + } + + for _, projectID := range projects { + var session types.Session + if err := p.storage.Get(context.Background(), []string{"session", projectID, sessionID}, &session); err == nil { + return &session, nil + } + } + return nil, fmt.Errorf("session %s not found", sessionID) +} + +func (p *Processor) saveSession(session *types.Session) error { + return p.storage.Put(context.Background(), []string{"session", session.ProjectID, session.ID}, session) +} + // checkDoomLoop detects and handles repetitive tool calls. func (p *Processor) checkDoomLoop( ctx context.Context, diff --git a/go-opencode/internal/session/tools_test.go b/go-opencode/internal/session/tools_test.go new file mode 100644 index 00000000000..32cae76b564 --- /dev/null +++ b/go-opencode/internal/session/tools_test.go @@ -0,0 +1,250 @@ +package session + +import ( + "strings" + "testing" +) + +func TestComputeDiff_SingleLineChange(t *testing.T) { + before := `module github.com/opencode-ai/opencode + +go 1.25 + +require ( + github.com/example/pkg v1.0.0 +)` + + after := `module github.com/opencode-ai/opencode + +go 1.24 + +require ( + github.com/example/pkg v1.0.0 +)` + + diffText, additions, deletions, err := computeDiff(before, after, "go.mod") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The change from "go 1.25" to "go 1.24" should result in 1 addition and 1 deletion + if additions != 1 { + t.Errorf("expected 1 addition, got %d", additions) + } + if deletions != 1 { + t.Errorf("expected 1 deletion, got %d", deletions) + } + + // diffText should not be empty + if diffText == "" { + t.Error("expected non-empty diff text") + } +} + +func TestComputeDiff_MultipleLineChanges(t *testing.T) { + before := `line1 +line2 +line3` + + after := `line1 +modified2 +line3 +line4` + + _, additions, deletions, err := computeDiff(before, after, "test.txt") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The diff algorithm groups changes differently: + // - "line2\nline3" gets replaced with "modified2\nline3\nline4" + // - This results in 3 lines added and 2 lines deleted + // The important thing is that additions > 0 when there are additions + if additions == 0 { + t.Error("expected non-zero additions") + } + if deletions == 0 { + t.Error("expected non-zero deletions") + } + // Net change: +1 line (from 3 to 4 lines) + if additions-deletions != 1 { + t.Errorf("expected net change of +1, got %d", additions-deletions) + } +} + +func TestComputeDiff_NoChanges(t *testing.T) { + content := `same content +on multiple lines` + + diffText, additions, deletions, err := computeDiff(content, content, "file.txt") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if additions != 0 { + t.Errorf("expected 0 additions, got %d", additions) + } + if deletions != 0 { + t.Errorf("expected 0 deletions, got %d", deletions) + } + + // No changes means empty diff or only headers + // Either way, additions and deletions should be 0 + _ = diffText +} + +func TestComputeDiff_NewFile(t *testing.T) { + before := "" + after := `new content +with two lines` + + _, additions, deletions, err := computeDiff(before, after, "new.txt") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // New file with 2 lines = 2 additions + if additions != 2 { + t.Errorf("expected 2 additions, got %d", additions) + } + if deletions != 0 { + t.Errorf("expected 0 deletions, got %d", deletions) + } +} + +func TestComputeDiff_DeletedFile(t *testing.T) { + before := `content to delete +second line` + after := "" + + _, additions, deletions, err := computeDiff(before, after, "deleted.txt") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if additions != 0 { + t.Errorf("expected 0 additions, got %d", additions) + } + // Deleted file with 2 lines = 2 deletions + if deletions != 2 { + t.Errorf("expected 2 deletions, got %d", deletions) + } +} + +func TestComputeDiff_UnifiedDiffFormat(t *testing.T) { + before := `line1 +line2 +line3` + + after := `line1 +modified2 +line3` + + diffText, _, _, err := computeDiff(before, after, "test.txt") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + t.Logf("Diff output:\n%s", diffText) + + // The diff text should be in proper unified diff format + // Each deleted line should be prefixed with "-" on its own line + // Each added line should be prefixed with "+" on its own line + + // Check that diffText contains proper line-by-line format + // It should NOT have "-line2+modified2" on the same line + if diffText == "" { + t.Error("expected non-empty diff text") + } + + // CRITICAL: The diff should NOT contain URL-encoded characters like %0A + // The TUI expects raw newlines, not URL-encoded ones + if strings.Contains(diffText, "%0A") { + t.Error("diff should not contain URL-encoded newlines (%0A)") + } + if strings.Contains(diffText, "%0D") { + t.Error("diff should not contain URL-encoded carriage returns (%0D)") + } + + // Verify the diff has proper structure: + // - Should have "--- test.txt" or "--- a/test.txt" header + // - Should have "+++ test.txt" or "+++ b/test.txt" header + // - Should have "-line2" on its own line (not merged with +) + // - Should have "+modified2" on its own line + + lines := splitLines(diffText) + + hasMinusHeader := false + hasPlusHeader := false + foundDeletedLine := false + foundAddedLine := false + + for _, line := range lines { + if strings.HasPrefix(line, "--- ") { + hasMinusHeader = true + } + if strings.HasPrefix(line, "+++ ") { + hasPlusHeader = true + } + // Check for proper deleted line format (starts with - but not ---) + if len(line) > 1 && line[0] == '-' && line[1] != '-' { + foundDeletedLine = true + // Verify it's on its own line (doesn't contain + after the content) + if containsAddedMarker(line) { + t.Errorf("deleted line should not contain '+' marker: %q", line) + } + } + // Check for proper added line format (starts with + but not +++) + if len(line) > 1 && line[0] == '+' && line[1] != '+' { + foundAddedLine = true + } + } + + if !hasMinusHeader { + t.Errorf("diff should have '--- ' header line: %s", diffText) + } + if !hasPlusHeader { + t.Errorf("diff should have '+++ ' header line: %s", diffText) + } + if !foundDeletedLine { + t.Errorf("diff should contain deleted line starting with '-': %s", diffText) + } + if !foundAddedLine { + t.Errorf("diff should contain added line starting with '+': %s", diffText) + } +} + +// splitLines splits text by newlines, similar to strings.Split but handles edge cases +func splitLines(text string) []string { + if text == "" { + return nil + } + var lines []string + start := 0 + for i := 0; i < len(text); i++ { + if text[i] == '\n' { + lines = append(lines, text[start:i]) + start = i + 1 + } + } + if start < len(text) { + lines = append(lines, text[start:]) + } + return lines +} + +// containsAddedMarker checks if line contains a '+' that's not at the start +func containsAddedMarker(line string) bool { + for i := 1; i < len(line); i++ { + if line[i] == '+' { + return true + } + } + return false +} diff --git a/go-opencode/internal/tool/bash.go b/go-opencode/internal/tool/bash.go index af42cab50ee..37b7d8f2959 100644 --- a/go-opencode/internal/tool/bash.go +++ b/go-opencode/internal/tool/bash.go @@ -114,7 +114,7 @@ func detectShell() string { return "/bin/sh" } -func (t *BashTool) ID() string { return "Bash" } +func (t *BashTool) ID() string { return "bash" } func (t *BashTool) Description() string { return bashDescription } func (t *BashTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/bash_test.go b/go-opencode/internal/tool/bash_test.go index 49b78f6da14..b09513272ca 100644 --- a/go-opencode/internal/tool/bash_test.go +++ b/go-opencode/internal/tool/bash_test.go @@ -69,8 +69,8 @@ func TestBashTool_WithTimeout(t *testing.T) { func TestBashTool_Properties(t *testing.T) { tool := NewBashTool("/tmp") - if tool.ID() != "Bash" { - t.Errorf("Expected ID 'Bash', got %q", tool.ID()) + if tool.ID() != "bash" { + t.Errorf("Expected ID 'bash', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/edit.go b/go-opencode/internal/tool/edit.go index 84416f8b684..273f0c43cfa 100644 --- a/go-opencode/internal/tool/edit.go +++ b/go-opencode/internal/tool/edit.go @@ -41,7 +41,7 @@ func NewEditTool(workDir string) *EditTool { return &EditTool{workDir: workDir} } -func (t *EditTool) ID() string { return "Edit" } +func (t *EditTool) ID() string { return "edit" } func (t *EditTool) Description() string { return editDescription } func (t *EditTool) Parameters() json.RawMessage { @@ -130,6 +130,8 @@ func (t *EditTool) Execute(ctx context.Context, input json.RawMessage, toolCtx * Metadata: map[string]any{ "file": params.FilePath, "replacements": count, + "before": text, + "after": newText, }, }, nil } @@ -159,6 +161,11 @@ func (t *EditTool) fuzzyReplace(text string, params EditInput, toolCtx *Context) return &Result{ Title: fmt.Sprintf("Edited %s (normalized)", filepath.Base(params.FilePath)), Output: "Replaced 1 occurrence (with line ending normalization)", + Metadata: map[string]any{ + "file": params.FilePath, + "before": text, + "after": newText, + }, }, nil } @@ -183,6 +190,11 @@ func (t *EditTool) fuzzyReplace(text string, params EditInput, toolCtx *Context) return &Result{ Title: fmt.Sprintf("Edited %s (fuzzy)", filepath.Base(params.FilePath)), Output: fmt.Sprintf("Replaced 1 occurrence (%.0f%% similarity)", similarity*100), + Metadata: map[string]any{ + "file": params.FilePath, + "before": text, + "after": newText, + }, }, nil } diff --git a/go-opencode/internal/tool/edit_test.go b/go-opencode/internal/tool/edit_test.go index 644828724c5..e3875d24c49 100644 --- a/go-opencode/internal/tool/edit_test.go +++ b/go-opencode/internal/tool/edit_test.go @@ -206,8 +206,8 @@ func TestEditTool_FuzzyMatchSimilarity(t *testing.T) { func TestEditTool_Properties(t *testing.T) { tool := NewEditTool("/tmp") - if tool.ID() != "Edit" { - t.Errorf("Expected ID 'Edit', got %q", tool.ID()) + if tool.ID() != "edit" { + t.Errorf("Expected ID 'edit', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/glob.go b/go-opencode/internal/tool/glob.go index b0dde80f8ac..3534911f497 100644 --- a/go-opencode/internal/tool/glob.go +++ b/go-opencode/internal/tool/glob.go @@ -34,7 +34,7 @@ func NewGlobTool(workDir string) *GlobTool { return &GlobTool{workDir: workDir} } -func (t *GlobTool) ID() string { return "Glob" } +func (t *GlobTool) ID() string { return "glob" } func (t *GlobTool) Description() string { return globDescription } func (t *GlobTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/glob_test.go b/go-opencode/internal/tool/glob_test.go index 03960a9fa6a..602059f23cb 100644 --- a/go-opencode/internal/tool/glob_test.go +++ b/go-opencode/internal/tool/glob_test.go @@ -77,8 +77,8 @@ func TestGlobTool_NoMatches(t *testing.T) { func TestGlobTool_Properties(t *testing.T) { tool := NewGlobTool("/tmp") - if tool.ID() != "Glob" { - t.Errorf("Expected ID 'Glob', got %q", tool.ID()) + if tool.ID() != "glob" { + t.Errorf("Expected ID 'glob', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/grep.go b/go-opencode/internal/tool/grep.go index 8d57f75d2b8..ff8967ac0e9 100644 --- a/go-opencode/internal/tool/grep.go +++ b/go-opencode/internal/tool/grep.go @@ -35,7 +35,7 @@ func NewGrepTool(workDir string) *GrepTool { return &GrepTool{workDir: workDir} } -func (t *GrepTool) ID() string { return "Grep" } +func (t *GrepTool) ID() string { return "grep" } func (t *GrepTool) Description() string { return grepDescription } func (t *GrepTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/grep_test.go b/go-opencode/internal/tool/grep_test.go index facc34fe5bf..cf07d59372a 100644 --- a/go-opencode/internal/tool/grep_test.go +++ b/go-opencode/internal/tool/grep_test.go @@ -110,8 +110,8 @@ func TestGrepTool_WithGlobFilter(t *testing.T) { func TestGrepTool_Properties(t *testing.T) { tool := NewGrepTool("/tmp") - if tool.ID() != "Grep" { - t.Errorf("Expected ID 'Grep', got %q", tool.ID()) + if tool.ID() != "grep" { + t.Errorf("Expected ID 'grep', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/list.go b/go-opencode/internal/tool/list.go index a7d0deb92e9..b8c8c5b2f85 100644 --- a/go-opencode/internal/tool/list.go +++ b/go-opencode/internal/tool/list.go @@ -32,7 +32,7 @@ func NewListTool(workDir string) *ListTool { return &ListTool{workDir: workDir} } -func (t *ListTool) ID() string { return "List" } +func (t *ListTool) ID() string { return "list" } func (t *ListTool) Description() string { return listDescription } func (t *ListTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/list_test.go b/go-opencode/internal/tool/list_test.go index b193a0151f3..638bd915382 100644 --- a/go-opencode/internal/tool/list_test.go +++ b/go-opencode/internal/tool/list_test.go @@ -117,8 +117,8 @@ func TestListTool_EmptyDirectory(t *testing.T) { func TestListTool_Properties(t *testing.T) { tool := NewListTool("/tmp") - if tool.ID() != "List" { - t.Errorf("Expected ID 'List', got %q", tool.ID()) + if tool.ID() != "list" { + t.Errorf("Expected ID 'list', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/read.go b/go-opencode/internal/tool/read.go index 546ded04c61..865b89e09eb 100644 --- a/go-opencode/internal/tool/read.go +++ b/go-opencode/internal/tool/read.go @@ -40,7 +40,7 @@ func NewReadTool(workDir string) *ReadTool { return &ReadTool{workDir: workDir} } -func (t *ReadTool) ID() string { return "Read" } +func (t *ReadTool) ID() string { return "read" } func (t *ReadTool) Description() string { return readDescription } func (t *ReadTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/read_test.go b/go-opencode/internal/tool/read_test.go index de07b5ca260..949f8517555 100644 --- a/go-opencode/internal/tool/read_test.go +++ b/go-opencode/internal/tool/read_test.go @@ -78,8 +78,8 @@ func TestReadTool_WithOffsetAndLimit(t *testing.T) { func TestReadTool_Properties(t *testing.T) { tool := NewReadTool("/tmp") - if tool.ID() != "Read" { - t.Errorf("Expected ID 'Read', got %q", tool.ID()) + if tool.ID() != "read" { + t.Errorf("Expected ID 'read', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/internal/tool/task.go b/go-opencode/internal/tool/task.go index 90883be8af3..0db6d23cb6c 100644 --- a/go-opencode/internal/tool/task.go +++ b/go-opencode/internal/tool/task.go @@ -80,7 +80,7 @@ func (t *TaskTool) SetExecutor(executor TaskExecutor) { t.executor = executor } -func (t *TaskTool) ID() string { return "Task" } +func (t *TaskTool) ID() string { return "task" } func (t *TaskTool) Description() string { return taskDescription } func (t *TaskTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/write.go b/go-opencode/internal/tool/write.go index 5f2050e7647..4e03e3b95f7 100644 --- a/go-opencode/internal/tool/write.go +++ b/go-opencode/internal/tool/write.go @@ -36,7 +36,7 @@ func NewWriteTool(workDir string) *WriteTool { return &WriteTool{workDir: workDir} } -func (t *WriteTool) ID() string { return "Write" } +func (t *WriteTool) ID() string { return "write" } func (t *WriteTool) Description() string { return writeDescription } func (t *WriteTool) Parameters() json.RawMessage { diff --git a/go-opencode/internal/tool/write_test.go b/go-opencode/internal/tool/write_test.go index 73b1a39a4f9..2f7960b71b9 100644 --- a/go-opencode/internal/tool/write_test.go +++ b/go-opencode/internal/tool/write_test.go @@ -91,8 +91,8 @@ func TestWriteTool_Overwrite(t *testing.T) { func TestWriteTool_Properties(t *testing.T) { tool := NewWriteTool("/tmp") - if tool.ID() != "Write" { - t.Errorf("Expected ID 'Write', got %q", tool.ID()) + if tool.ID() != "write" { + t.Errorf("Expected ID 'write', got %q", tool.ID()) } desc := tool.Description() diff --git a/go-opencode/pkg/types/config.go b/go-opencode/pkg/types/config.go index f51a1aee7c9..e4dba22b811 100644 --- a/go-opencode/pkg/types/config.go +++ b/go-opencode/pkg/types/config.go @@ -16,6 +16,9 @@ type Config struct { // Theme (TUI only, for compatibility) Theme string `json:"theme,omitempty"` + // Keybinds (TUI shortcut configuration) + Keybinds Keybinds `json:"keybinds"` + // Sharing behavior Share string `json:"share,omitempty"` // "manual"|"auto"|"disabled" @@ -180,6 +183,243 @@ type ExperimentalConfig struct { BatchTool bool `json:"batch_tool,omitempty"` } +// Keybinds defines TUI keyboard shortcuts. Keep field order and names aligned +// with the TypeScript config schema for compatibility. +type Keybinds struct { + Leader string `json:"leader"` + AppExit string `json:"app_exit"` + EditorOpen string `json:"editor_open"` + ThemeList string `json:"theme_list"` + SidebarToggle string `json:"sidebar_toggle"` + UsernameToggle string `json:"username_toggle"` + StatusView string `json:"status_view"` + SessionExport string `json:"session_export"` + SessionNew string `json:"session_new"` + SessionList string `json:"session_list"` + SessionTimeline string `json:"session_timeline"` + SessionShare string `json:"session_share"` + SessionUnshare string `json:"session_unshare"` + SessionInterrupt string `json:"session_interrupt"` + SessionCompact string `json:"session_compact"` + MessagesPageUp string `json:"messages_page_up"` + MessagesPageDown string `json:"messages_page_down"` + MessagesHalfPageUp string `json:"messages_half_page_up"` + MessagesHalfPageDown string `json:"messages_half_page_down"` + MessagesFirst string `json:"messages_first"` + MessagesLast string `json:"messages_last"` + MessagesLastUser string `json:"messages_last_user"` + MessagesCopy string `json:"messages_copy"` + MessagesUndo string `json:"messages_undo"` + MessagesRedo string `json:"messages_redo"` + MessagesToggleConceal string `json:"messages_toggle_conceal"` + ToolDetails string `json:"tool_details"` + ModelList string `json:"model_list"` + ModelCycleRecent string `json:"model_cycle_recent"` + ModelCycleRecentReverse string `json:"model_cycle_recent_reverse"` + CommandList string `json:"command_list"` + AgentList string `json:"agent_list"` + AgentCycle string `json:"agent_cycle"` + AgentCycleReverse string `json:"agent_cycle_reverse"` + InputClear string `json:"input_clear"` + InputForwardDelete string `json:"input_forward_delete"` + InputPaste string `json:"input_paste"` + InputSubmit string `json:"input_submit"` + InputNewline string `json:"input_newline"` + HistoryPrevious string `json:"history_previous"` + HistoryNext string `json:"history_next"` + SessionChildCycle string `json:"session_child_cycle"` + SessionChildCycleReverse string `json:"session_child_cycle_reverse"` + TerminalSuspend string `json:"terminal_suspend"` +} + +// DefaultKeybinds returns the default TUI keybindings, matching the TypeScript implementation. +func DefaultKeybinds() Keybinds { + return Keybinds{ + Leader: "ctrl+x", + AppExit: "ctrl+c,ctrl+d,q", + EditorOpen: "e", + ThemeList: "t", + SidebarToggle: "b", + UsernameToggle: "none", + StatusView: "s", + SessionExport: "x", + SessionNew: "n", + SessionList: "l", + SessionTimeline: "g", + SessionShare: "none", + SessionUnshare: "none", + SessionInterrupt: "escape", + SessionCompact: "c", + MessagesPageUp: "pageup", + MessagesPageDown: "pagedown", + MessagesHalfPageUp: "ctrl+alt+u", + MessagesHalfPageDown: "ctrl+alt+d", + MessagesFirst: "ctrl+g,home", + MessagesLast: "ctrl+alt+g,end", + MessagesLastUser: "none", + MessagesCopy: "y", + MessagesUndo: "u", + MessagesRedo: "r", + MessagesToggleConceal: "h", + ToolDetails: "none", + ModelList: "m", + ModelCycleRecent: "f2", + ModelCycleRecentReverse: "shift+f2", + CommandList: "ctrl+p", + AgentList: "a", + AgentCycle: "tab", + AgentCycleReverse: "shift+tab", + InputClear: "ctrl+c", + InputForwardDelete: "ctrl+d", + InputPaste: "ctrl+v", + InputSubmit: "return", + InputNewline: "shift+return,ctrl+j", + HistoryPrevious: "up", + HistoryNext: "down", + SessionChildCycle: "right", + SessionChildCycleReverse: "left", + TerminalSuspend: "ctrl+z", + } +} + +// MergeKeybinds overlays overrides on top of base defaults, skipping empty values. +func MergeKeybinds(base, overrides Keybinds) Keybinds { + if overrides.Leader != "" { + base.Leader = overrides.Leader + } + if overrides.AppExit != "" { + base.AppExit = overrides.AppExit + } + if overrides.EditorOpen != "" { + base.EditorOpen = overrides.EditorOpen + } + if overrides.ThemeList != "" { + base.ThemeList = overrides.ThemeList + } + if overrides.SidebarToggle != "" { + base.SidebarToggle = overrides.SidebarToggle + } + if overrides.UsernameToggle != "" { + base.UsernameToggle = overrides.UsernameToggle + } + if overrides.StatusView != "" { + base.StatusView = overrides.StatusView + } + if overrides.SessionExport != "" { + base.SessionExport = overrides.SessionExport + } + if overrides.SessionNew != "" { + base.SessionNew = overrides.SessionNew + } + if overrides.SessionList != "" { + base.SessionList = overrides.SessionList + } + if overrides.SessionTimeline != "" { + base.SessionTimeline = overrides.SessionTimeline + } + if overrides.SessionShare != "" { + base.SessionShare = overrides.SessionShare + } + if overrides.SessionUnshare != "" { + base.SessionUnshare = overrides.SessionUnshare + } + if overrides.SessionInterrupt != "" { + base.SessionInterrupt = overrides.SessionInterrupt + } + if overrides.SessionCompact != "" { + base.SessionCompact = overrides.SessionCompact + } + if overrides.MessagesPageUp != "" { + base.MessagesPageUp = overrides.MessagesPageUp + } + if overrides.MessagesPageDown != "" { + base.MessagesPageDown = overrides.MessagesPageDown + } + if overrides.MessagesHalfPageUp != "" { + base.MessagesHalfPageUp = overrides.MessagesHalfPageUp + } + if overrides.MessagesHalfPageDown != "" { + base.MessagesHalfPageDown = overrides.MessagesHalfPageDown + } + if overrides.MessagesFirst != "" { + base.MessagesFirst = overrides.MessagesFirst + } + if overrides.MessagesLast != "" { + base.MessagesLast = overrides.MessagesLast + } + if overrides.MessagesLastUser != "" { + base.MessagesLastUser = overrides.MessagesLastUser + } + if overrides.MessagesCopy != "" { + base.MessagesCopy = overrides.MessagesCopy + } + if overrides.MessagesUndo != "" { + base.MessagesUndo = overrides.MessagesUndo + } + if overrides.MessagesRedo != "" { + base.MessagesRedo = overrides.MessagesRedo + } + if overrides.MessagesToggleConceal != "" { + base.MessagesToggleConceal = overrides.MessagesToggleConceal + } + if overrides.ToolDetails != "" { + base.ToolDetails = overrides.ToolDetails + } + if overrides.ModelList != "" { + base.ModelList = overrides.ModelList + } + if overrides.ModelCycleRecent != "" { + base.ModelCycleRecent = overrides.ModelCycleRecent + } + if overrides.ModelCycleRecentReverse != "" { + base.ModelCycleRecentReverse = overrides.ModelCycleRecentReverse + } + if overrides.CommandList != "" { + base.CommandList = overrides.CommandList + } + if overrides.AgentList != "" { + base.AgentList = overrides.AgentList + } + if overrides.AgentCycle != "" { + base.AgentCycle = overrides.AgentCycle + } + if overrides.AgentCycleReverse != "" { + base.AgentCycleReverse = overrides.AgentCycleReverse + } + if overrides.InputClear != "" { + base.InputClear = overrides.InputClear + } + if overrides.InputForwardDelete != "" { + base.InputForwardDelete = overrides.InputForwardDelete + } + if overrides.InputPaste != "" { + base.InputPaste = overrides.InputPaste + } + if overrides.InputSubmit != "" { + base.InputSubmit = overrides.InputSubmit + } + if overrides.InputNewline != "" { + base.InputNewline = overrides.InputNewline + } + if overrides.HistoryPrevious != "" { + base.HistoryPrevious = overrides.HistoryPrevious + } + if overrides.HistoryNext != "" { + base.HistoryNext = overrides.HistoryNext + } + if overrides.SessionChildCycle != "" { + base.SessionChildCycle = overrides.SessionChildCycle + } + if overrides.SessionChildCycleReverse != "" { + base.SessionChildCycleReverse = overrides.SessionChildCycleReverse + } + if overrides.TerminalSuspend != "" { + base.TerminalSuspend = overrides.TerminalSuspend + } + + return base +} + // Model represents an LLM model available from a provider. type Model struct { ID string `json:"id"` diff --git a/go-opencode/pkg/types/message.go b/go-opencode/pkg/types/message.go index d498fa2bcc0..53b799c447a 100644 --- a/go-opencode/pkg/types/message.go +++ b/go-opencode/pkg/types/message.go @@ -1,28 +1,102 @@ package types +import "encoding/json" + // Message represents either a User or Assistant message in a conversation. type Message struct { - ID string `json:"id"` - SessionID string `json:"sessionID"` - Role string `json:"role"` // "user" | "assistant" - Time MessageTime `json:"time"` + ID string `json:"id"` + SessionID string `json:"sessionID"` + Role string `json:"role"` // "user" | "assistant" + Time MessageTime `json:"time"` // User-specific fields - Agent string `json:"agent,omitempty"` - Model *ModelRef `json:"model,omitempty"` - System *string `json:"system,omitempty"` - Tools map[string]bool `json:"tools,omitempty"` + Agent string `json:"agent,omitempty"` + Model *ModelRef `json:"model,omitempty"` + System *string `json:"system,omitempty"` + Tools map[string]bool `json:"tools,omitempty"` + Summary *UserMessageSummary `json:"-"` // Summary with title and diffs (for user messages) // Assistant-specific fields + ParentID string `json:"parentID,omitempty"` // Links to the user message that prompted this ModelID string `json:"modelID,omitempty"` ProviderID string `json:"providerID,omitempty"` Mode string `json:"mode,omitempty"` // Agent name (e.g., "Coder", "Build") + Path *MessagePath `json:"path,omitempty"` // Current working directory and root + IsSummary bool `json:"-"` // True if this is a summary/compaction message (for assistant messages) Finish *string `json:"finish,omitempty"` Cost float64 `json:"cost"` // Required by TUI Tokens *TokenUsage `json:"tokens,omitempty"` Error *MessageError `json:"error,omitempty"` } +// MarshalJSON implements custom JSON marshaling to handle the summary field +// differently based on message role. +func (m Message) MarshalJSON() ([]byte, error) { + type Alias Message + aux := struct { + Alias + Summary any `json:"summary,omitempty"` + }{ + Alias: Alias(m), + } + + // Set the appropriate summary field based on role + if m.Role == "user" && m.Summary != nil { + aux.Summary = m.Summary + } else if m.Role == "assistant" && m.IsSummary { + aux.Summary = true + } + + return json.Marshal(aux) +} + +// UnmarshalJSON implements custom JSON unmarshaling to handle the summary field +// differently based on message role. +func (m *Message) UnmarshalJSON(data []byte) error { + type Alias Message + aux := struct { + *Alias + Summary json.RawMessage `json:"summary,omitempty"` + }{ + Alias: (*Alias)(m), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Parse summary based on role + if len(aux.Summary) > 0 { + if m.Role == "user" { + var summary UserMessageSummary + if err := json.Unmarshal(aux.Summary, &summary); err == nil { + m.Summary = &summary + } + } else if m.Role == "assistant" { + var isSummary bool + if err := json.Unmarshal(aux.Summary, &isSummary); err == nil { + m.IsSummary = isSummary + } + } + } + + return nil +} + +// MessagePath contains the current working directory and project root. +type MessagePath struct { + Cwd string `json:"cwd"` + Root string `json:"root"` +} + +// UserMessageSummary contains summary information for a user message. +// Uses FileDiff from session.go for diffs. +type UserMessageSummary struct { + Title string `json:"title,omitempty"` + Body string `json:"body,omitempty"` + Diffs []FileDiff `json:"diffs,omitempty"` +} + // MessageTime contains timestamps for a message. type MessageTime struct { Created int64 `json:"created"` diff --git a/go-opencode/pkg/types/parts.go b/go-opencode/pkg/types/parts.go index f9b7a42f830..11875a5a5e7 100644 --- a/go-opencode/pkg/types/parts.go +++ b/go-opencode/pkg/types/parts.go @@ -108,6 +108,54 @@ func (p *FilePart) PartID() string { return p.ID } func (p *FilePart) PartSessionID() string { return p.SessionID } func (p *FilePart) PartMessageID() string { return p.MessageID } +// StepStartPart marks the beginning of an inference step. +// SDK compatible: includes sessionID and messageID fields. +type StepStartPart struct { + ID string `json:"id"` + SessionID string `json:"sessionID"` // SDK compatible + MessageID string `json:"messageID"` // SDK compatible + Type string `json:"type"` // always "step-start" + Snapshot string `json:"snapshot,omitempty"` +} + +func (p *StepStartPart) PartType() string { return "step-start" } +func (p *StepStartPart) PartID() string { return p.ID } +func (p *StepStartPart) PartSessionID() string { return p.SessionID } +func (p *StepStartPart) PartMessageID() string { return p.MessageID } + +// StepFinishPart marks the end of an inference step with cost and token info. +// SDK compatible: includes sessionID and messageID fields. +type StepFinishPart struct { + ID string `json:"id"` + SessionID string `json:"sessionID"` // SDK compatible + MessageID string `json:"messageID"` // SDK compatible + Type string `json:"type"` // always "step-finish" + Reason string `json:"reason"` // e.g., "stop", "tool-calls" + Snapshot string `json:"snapshot,omitempty"` + Cost float64 `json:"cost"` + Tokens *TokenUsage `json:"tokens,omitempty"` +} + +func (p *StepFinishPart) PartType() string { return "step-finish" } +func (p *StepFinishPart) PartID() string { return p.ID } +func (p *StepFinishPart) PartSessionID() string { return p.SessionID } +func (p *StepFinishPart) PartMessageID() string { return p.MessageID } + +// CompactionPart represents a request to compact/summarize the conversation. +// SDK compatible: includes sessionID and messageID fields. +type CompactionPart struct { + ID string `json:"id"` + SessionID string `json:"sessionID"` // SDK compatible + MessageID string `json:"messageID"` // SDK compatible + Type string `json:"type"` // always "compaction" + Auto bool `json:"auto"` // Whether this was triggered automatically +} + +func (p *CompactionPart) PartType() string { return "compaction" } +func (p *CompactionPart) PartID() string { return p.ID } +func (p *CompactionPart) PartSessionID() string { return p.SessionID } +func (p *CompactionPart) PartMessageID() string { return p.MessageID } + // RawPart is used for JSON unmarshaling of parts. type RawPart struct { ID string `json:"id"` @@ -147,6 +195,24 @@ func UnmarshalPart(data []byte) (Part, error) { return nil, err } return &p, nil + case "step-start": + var p StepStartPart + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil + case "step-finish": + var p StepFinishPart + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil + case "compaction": + var p CompactionPart + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil default: // Return raw part for unknown types var p TextPart diff --git a/go-opencode/pkg/types/session.go b/go-opencode/pkg/types/session.go index 870d1ca031e..72140591200 100644 --- a/go-opencode/pkg/types/session.go +++ b/go-opencode/pkg/types/session.go @@ -3,17 +3,17 @@ package types // Session represents a conversation session with the LLM. type Session struct { - ID string `json:"id"` - ProjectID string `json:"projectID"` - Directory string `json:"directory"` - ParentID *string `json:"parentID,omitempty"` - Title string `json:"title"` - Version string `json:"version"` - Summary SessionSummary `json:"summary"` - Share *SessionShare `json:"share,omitempty"` - Time SessionTime `json:"time"` - Revert *SessionRevert `json:"revert,omitempty"` - CustomPrompt *CustomPrompt `json:"customPrompt,omitempty"` + ID string `json:"id"` + ProjectID string `json:"projectID"` + Directory string `json:"directory"` + ParentID *string `json:"parentID,omitempty"` + Title string `json:"title"` + Version string `json:"version"` + Summary SessionSummary `json:"summary"` + Share *SessionShare `json:"share,omitempty"` + Time SessionTime `json:"time"` + Revert *SessionRevert `json:"revert,omitempty"` + CustomPrompt *CustomPrompt `json:"customPrompt,omitempty"` } // SessionSummary contains statistics about code changes in a session. @@ -26,7 +26,7 @@ type SessionSummary struct { // FileDiff represents a diff for a single file. type FileDiff struct { - Path string `json:"path"` + File string `json:"file"` Additions int `json:"additions"` Deletions int `json:"deletions"` Before string `json:"before,omitempty"` @@ -79,10 +79,10 @@ const ( // Project represents a project (worktree) that can contain sessions. type Project struct { - ID string `json:"id"` - Worktree string `json:"worktree"` - VCS *string `json:"vcs,omitempty"` // "git" or nil - Time ProjectTime `json:"time"` + ID string `json:"id"` + Worktree string `json:"worktree"` + VCS *string `json:"vcs,omitempty"` // "git" or nil + Time ProjectTime `json:"time"` } // ProjectTime contains timestamps for a project. diff --git a/go-opencode/pkg/types/types_test.go b/go-opencode/pkg/types/types_test.go index 950a2265818..f8a31699ff6 100644 --- a/go-opencode/pkg/types/types_test.go +++ b/go-opencode/pkg/types/types_test.go @@ -79,12 +79,12 @@ func TestSession_OptionalFields(t *testing.T) { func TestMessage_JSON(t *testing.T) { msg := Message{ - ID: "msg-123", - SessionID: "session-456", - Role: "assistant", - ModelID: "claude-3-opus", + ID: "msg-123", + SessionID: "session-456", + Role: "assistant", + ModelID: "claude-3-opus", ProviderID: "anthropic", - Cost: 0.05, + Cost: 0.05, Tokens: &TokenUsage{ Input: 1000, Output: 500, @@ -162,7 +162,7 @@ func TestMessage_UserFields(t *testing.T) { func TestFileDiff_JSON(t *testing.T) { diff := FileDiff{ - Path: "/src/main.go", + File: "/src/main.go", Additions: 10, Deletions: 5, Before: "func old() {}", @@ -179,8 +179,8 @@ func TestFileDiff_JSON(t *testing.T) { t.Fatalf("Unmarshal failed: %v", err) } - if decoded.Path != diff.Path { - t.Errorf("Path mismatch: got %s, want %s", decoded.Path, diff.Path) + if decoded.File != diff.File { + t.Errorf("File mismatch: got %s, want %s", decoded.File, diff.File) } } @@ -251,3 +251,146 @@ func TestMessageError_JSON(t *testing.T) { t.Errorf("Name mismatch: got %s, want UnknownError", decoded.Name) } } + +func TestMessage_SummaryField_UserMessage(t *testing.T) { + // User message should have summary as an object + msg := Message{ + ID: "msg-user-1", + SessionID: "session-1", + Role: "user", + Agent: "main", + Summary: &UserMessageSummary{ + Title: "Fixed a bug", + Body: "Fixed the rendering issue", + Diffs: []FileDiff{{File: "main.go", Additions: 5, Deletions: 2}}, + }, + Time: MessageTime{Created: 1700000000000}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Verify summary is an object + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("Unmarshal to map failed: %v", err) + } + + summary, ok := raw["summary"].(map[string]any) + if !ok { + t.Fatalf("summary should be an object, got %T: %v", raw["summary"], raw["summary"]) + } + if summary["title"] != "Fixed a bug" { + t.Errorf("summary.title mismatch: got %v", summary["title"]) + } + + // Round-trip test + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if decoded.Summary == nil || decoded.Summary.Title != "Fixed a bug" { + t.Error("Summary not properly decoded") + } +} + +func TestMessage_SummaryField_AssistantMessage(t *testing.T) { + // Assistant message should have summary as a boolean + msg := Message{ + ID: "msg-assistant-1", + SessionID: "session-1", + Role: "assistant", + ParentID: "msg-user-1", + ModelID: "claude-3-opus", + ProviderID: "anthropic", + IsSummary: true, // This is a compaction summary message + Cost: 0.05, + Tokens: &TokenUsage{ + Input: 1000, + Output: 500, + Cache: CacheUsage{Read: 0, Write: 0}, + }, + Time: MessageTime{Created: 1700000000000}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Verify summary is a boolean + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("Unmarshal to map failed: %v", err) + } + + summary, ok := raw["summary"].(bool) + if !ok { + t.Fatalf("summary should be a boolean, got %T: %v", raw["summary"], raw["summary"]) + } + if !summary { + t.Error("summary should be true") + } + + // Round-trip test + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !decoded.IsSummary { + t.Error("IsSummary not properly decoded") + } +} + +func TestMessage_SummaryField_OmittedWhenNotSet(t *testing.T) { + // Test that summary is omitted when not set + msg := Message{ + ID: "msg-user-1", + SessionID: "session-1", + Role: "user", + Agent: "main", + Time: MessageTime{Created: 1700000000000}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("Unmarshal to map failed: %v", err) + } + + if _, ok := raw["summary"]; ok { + t.Error("summary should be omitted when not set") + } + + // Same for assistant without IsSummary + msg2 := Message{ + ID: "msg-assistant-1", + SessionID: "session-1", + Role: "assistant", + ParentID: "msg-user-1", + ModelID: "claude-3-opus", + ProviderID: "anthropic", + IsSummary: false, + Cost: 0.05, + Tokens: &TokenUsage{ + Input: 1000, + Output: 500, + Cache: CacheUsage{Read: 0, Write: 0}, + }, + Time: MessageTime{Created: 1700000000000}, + } + + data2, _ := json.Marshal(msg2) + var raw2 map[string]any + json.Unmarshal(data2, &raw2) + + if _, ok := raw2["summary"]; ok { + t.Error("summary should be omitted when IsSummary is false") + } +}