From 987c3ad18b7589f526c09f549dfbca72326e5ec9 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 28 Dec 2025 20:45:27 -0500 Subject: [PATCH] ide: add document symbols --- crates/squawk_ide/src/document_symbols.rs | 234 ++++++++++++++++++++++ crates/squawk_ide/src/lib.rs | 1 + crates/squawk_ide/src/resolve.rs | 26 ++- crates/squawk_server/src/lib.rs | 67 ++++++- 4 files changed, 314 insertions(+), 14 deletions(-) create mode 100644 crates/squawk_ide/src/document_symbols.rs diff --git a/crates/squawk_ide/src/document_symbols.rs b/crates/squawk_ide/src/document_symbols.rs new file mode 100644 index 00000000..65ae7362 --- /dev/null +++ b/crates/squawk_ide/src/document_symbols.rs @@ -0,0 +1,234 @@ +use rowan::TextRange; +use squawk_syntax::ast::{self, AstNode}; + +use crate::binder; +use crate::resolve::{resolve_function_info, resolve_table_info}; + +pub enum DocumentSymbolKind { + Table, + Function, +} + +pub struct DocumentSymbol { + pub name: String, + pub detail: Option, + pub kind: DocumentSymbolKind, + pub range: TextRange, + pub selection_range: TextRange, +} + +pub fn document_symbols(file: &ast::SourceFile) -> Vec { + let binder = binder::bind(file); + let mut symbols = vec![]; + + for stmt in file.stmts() { + match stmt { + ast::Stmt::CreateTable(create_table) => { + if let Some(symbol) = create_table_symbol(&binder, create_table) { + symbols.push(symbol); + } + } + ast::Stmt::CreateFunction(create_function) => { + if let Some(symbol) = create_function_symbol(&binder, create_function) { + symbols.push(symbol); + } + } + _ => {} + } + } + + symbols +} + +fn create_table_symbol( + binder: &binder::Binder, + create_table: ast::CreateTable, +) -> Option { + let path = create_table.path()?; + let segment = path.segment()?; + let name_node = segment.name()?; + + let (schema, table_name) = resolve_table_info(binder, &path)?; + let name = format!("{}.{}", schema.0, table_name); + + let range = create_table.syntax().text_range(); + let selection_range = name_node.syntax().text_range(); + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::Table, + range, + selection_range, + }) +} + +fn create_function_symbol( + binder: &binder::Binder, + create_function: ast::CreateFunction, +) -> Option { + let path = create_function.path()?; + let segment = path.segment()?; + let name_node = segment.name()?; + + let (schema, function_name) = resolve_function_info(binder, &path)?; + let name = format!("{}.{}", schema.0, function_name); + + let range = create_function.syntax().text_range(); + let selection_range = name_node.syntax().text_range(); + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::Function, + range, + selection_range, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle}; + use insta::assert_snapshot; + + fn symbols_not_found(sql: &str) { + let parse = ast::SourceFile::parse(sql); + let file = parse.tree(); + let symbols = document_symbols(&file); + if !symbols.is_empty() { + panic!("Symbols found. If this is expected, use `symbols` instead.") + } + } + + fn symbols(sql: &str) -> String { + let parse = ast::SourceFile::parse(sql); + let file = parse.tree(); + let symbols = document_symbols(&file); + if symbols.is_empty() { + panic!("No symbols found. If this is expected, use `symbols_not_found` instead.") + } + + let mut groups = vec![]; + for symbol in symbols { + let kind = match symbol.kind { + DocumentSymbolKind::Table => "table", + DocumentSymbolKind::Function => "function", + }; + let title = format!("{}: {}", kind, symbol.name); + let group = Level::INFO.primary_title(title).element( + Snippet::source(sql) + .fold(true) + .annotation( + AnnotationKind::Primary + .span(symbol.selection_range.into()) + .label("name"), + ) + .annotation( + AnnotationKind::Context + .span(symbol.range.into()) + .label("select range"), + ), + ); + groups.push(group); + } + + let renderer = Renderer::plain().decor_style(DecorStyle::Unicode); + renderer.render(&groups).to_string() + } + + #[test] + fn create_table() { + assert_snapshot!(symbols("create table users (id int);"), @r" + info: table: public.users + ╭▸ + 1 │ create table users (id int); + │ ┬────────────┯━━━━───────── + │ │ │ + │ │ name + ╰╴select range + "); + } + + #[test] + fn create_function() { + assert_snapshot!( + symbols("create function hello() returns void as $$ select 1; $$ language sql;"), + @r" + info: function: public.hello + ╭▸ + 1 │ create function hello() returns void as $$ select 1; $$ language sql; + │ ┬───────────────┯━━━━─────────────────────────────────────────────── + │ │ │ + │ │ name + ╰╴select range + " + ); + } + + #[test] + fn multiple_symbols() { + assert_snapshot!(symbols(" +create table users (id int); +create table posts (id int); +create function get_user(user_id int) returns void as $$ select 1; $$ language sql; +"), @r" + info: table: public.users + ╭▸ + 2 │ create table users (id int); + │ ┬────────────┯━━━━───────── + │ │ │ + │ │ name + │ select range + ╰╴ + info: table: public.posts + ╭▸ + 3 │ create table posts (id int); + │ ┬────────────┯━━━━───────── + │ │ │ + │ │ name + ╰╴select range + info: function: public.get_user + ╭▸ + 4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql; + │ ┬───────────────┯━━━━━━━────────────────────────────────────────────────────────── + │ │ │ + │ │ name + ╰╴select range + "); + } + + #[test] + fn qualified_names() { + assert_snapshot!(symbols(" +create table public.users (id int); +create function my_schema.hello() returns void as $$ select 1; $$ language sql; +"), @r" + info: table: public.users + ╭▸ + 2 │ create table public.users (id int); + │ ┬───────────────────┯━━━━───────── + │ │ │ + │ │ name + │ select range + ╰╴ + info: function: my_schema.hello + ╭▸ + 3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql; + │ ┬─────────────────────────┯━━━━─────────────────────────────────────────────── + │ │ │ + │ │ name + ╰╴select range + "); + } + + #[test] + fn empty_file() { + symbols_not_found("") + } + + #[test] + fn non_create_statements() { + symbols_not_found("select * from users;") + } +} diff --git a/crates/squawk_ide/src/lib.rs b/crates/squawk_ide/src/lib.rs index 4aeca3ad..cd59c0d1 100644 --- a/crates/squawk_ide/src/lib.rs +++ b/crates/squawk_ide/src/lib.rs @@ -1,6 +1,7 @@ mod binder; pub mod code_actions; pub mod column_name; +pub mod document_symbols; pub mod expand_selection; pub mod find_references; mod generated; diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index dc3b4b6b..3c585a72 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -779,30 +779,42 @@ pub(crate) fn resolve_insert_table_columns( } pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { - let table_name_str = extract_table_name_from_path(path)?; + resolve_symbol_info(binder, path, SymbolKind::Table) +} + +pub(crate) fn resolve_function_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { + resolve_symbol_info(binder, path, SymbolKind::Function) +} + +fn resolve_symbol_info( + binder: &Binder, + path: &ast::Path, + kind: SymbolKind, +) -> Option<(Schema, String)> { + let name_str = extract_table_name_from_path(path)?; let schema = extract_schema_from_path(path); - let table_name_normalized = Name::new(table_name_str.clone()); - let symbols = binder.scopes[binder.root_scope()].get(&table_name_normalized)?; + let name_normalized = Name::new(name_str.clone()); + let symbols = binder.scopes[binder.root_scope()].get(&name_normalized)?; if let Some(schema_name) = schema { let schema_normalized = Schema::new(schema_name); let symbol_id = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Table && symbol.schema == schema_normalized + symbol.kind == kind && symbol.schema == schema_normalized })?; let symbol = &binder.symbols[symbol_id]; - return Some((symbol.schema.clone(), table_name_str)); + return Some((symbol.schema.clone(), name_str)); } else { let position = path.syntax().text_range().start(); let search_path = binder.search_path_at(position); for search_schema in search_path { if let Some(symbol_id) = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Table && &symbol.schema == search_schema + symbol.kind == kind && &symbol.schema == search_schema }) { let symbol = &binder.symbols[symbol_id]; - return Some((symbol.schema.clone(), table_name_str)); + return Some((symbol.schema.clone(), name_str)); } } } diff --git a/crates/squawk_server/src/lib.rs b/crates/squawk_server/src/lib.rs index 6613b9bd..d6bf303e 100644 --- a/crates/squawk_server/src/lib.rs +++ b/crates/squawk_server/src/lib.rs @@ -6,23 +6,24 @@ use lsp_types::{ CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams, CodeActionProviderCapability, CodeActionResponse, Command, Diagnostic, DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams, - GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverContents, HoverParams, - HoverProviderCapability, InitializeParams, InlayHint, InlayHintKind, InlayHintLabel, - InlayHintLabelPart, InlayHintParams, LanguageString, Location, MarkedString, OneOf, - PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams, - SelectionRangeProviderCapability, ServerCapabilities, TextDocumentSyncCapability, + DocumentSymbol, DocumentSymbolParams, GotoDefinitionParams, GotoDefinitionResponse, Hover, + HoverContents, HoverParams, HoverProviderCapability, InitializeParams, InlayHint, + InlayHintKind, InlayHintLabel, InlayHintLabelPart, InlayHintParams, LanguageString, Location, + MarkedString, OneOf, PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams, + SelectionRangeProviderCapability, ServerCapabilities, SymbolKind, TextDocumentSyncCapability, TextDocumentSyncKind, Url, WorkDoneProgressOptions, WorkspaceEdit, notification::{ DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _, PublishDiagnostics, }, request::{ - CodeActionRequest, GotoDefinition, HoverRequest, InlayHintRequest, References, Request, - SelectionRangeRequest, + CodeActionRequest, DocumentSymbolRequest, GotoDefinition, HoverRequest, InlayHintRequest, + References, Request, SelectionRangeRequest, }, }; use rowan::TextRange; use squawk_ide::code_actions::code_actions; +use squawk_ide::document_symbols::{DocumentSymbolKind, document_symbols}; use squawk_ide::find_references::find_references; use squawk_ide::goto_definition::goto_definition; use squawk_ide::hover::hover; @@ -67,6 +68,7 @@ pub fn run() -> Result<()> { definition_provider: Some(OneOf::Left(true)), hover_provider: Some(HoverProviderCapability::Simple(true)), inlay_hint_provider: Some(OneOf::Left(true)), + document_symbol_provider: Some(OneOf::Left(true)), ..Default::default() }) .unwrap(); @@ -119,6 +121,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { InlayHintRequest::METHOD => { handle_inlay_hints(&connection, req, &documents)?; } + DocumentSymbolRequest::METHOD => { + handle_document_symbol(&connection, req, &documents)?; + } "squawk/syntaxTree" => { handle_syntax_tree(&connection, req, &documents)?; } @@ -296,6 +301,54 @@ fn handle_inlay_hints( Ok(()) } +fn handle_document_symbol( + connection: &Connection, + req: lsp_server::Request, + documents: &HashMap, +) -> Result<()> { + let params: DocumentSymbolParams = serde_json::from_value(req.params)?; + let uri = params.text_document.uri; + + let content = documents.get(&uri).map_or("", |doc| &doc.content); + let parse = SourceFile::parse(content); + let file = parse.tree(); + let line_index = LineIndex::new(content); + + let symbols = document_symbols(&file); + + let lsp_symbols: Vec = symbols + .into_iter() + .map(|sym| { + let range = lsp_utils::range(&line_index, sym.range); + let selection_range = lsp_utils::range(&line_index, sym.selection_range); + + DocumentSymbol { + name: sym.name, + detail: sym.detail, + kind: match sym.kind { + DocumentSymbolKind::Table => SymbolKind::STRUCT, + DocumentSymbolKind::Function => SymbolKind::FUNCTION, + }, + tags: None, + range, + selection_range, + children: None, + #[allow(deprecated)] + deprecated: None, + } + }) + .collect(); + + let resp = Response { + id: req.id, + result: Some(serde_json::to_value(&lsp_symbols).unwrap()), + error: None, + }; + + connection.sender.send(Message::Response(resp))?; + Ok(()) +} + fn handle_selection_range( connection: &Connection, req: lsp_server::Request,