diff --git a/crates/squawk_ide/src/document_symbols.rs b/crates/squawk_ide/src/document_symbols.rs index 65ae7362..a617a1bf 100644 --- a/crates/squawk_ide/src/document_symbols.rs +++ b/crates/squawk_ide/src/document_symbols.rs @@ -4,17 +4,24 @@ use squawk_syntax::ast::{self, AstNode}; use crate::binder; use crate::resolve::{resolve_function_info, resolve_table_info}; +#[derive(Debug)] pub enum DocumentSymbolKind { Table, Function, + Column, } +#[derive(Debug)] pub struct DocumentSymbol { pub name: String, pub detail: Option, pub kind: DocumentSymbolKind, - pub range: TextRange, - pub selection_range: TextRange, + /// Range used for determining when cursor is inside the symbol for showing + /// in the UI + pub full_range: TextRange, + /// Range selected when symbol is selected + pub focus_range: TextRange, + pub children: Vec, } pub fn document_symbols(file: &ast::SourceFile) -> Vec { @@ -51,15 +58,27 @@ fn create_table_symbol( let (schema, table_name) = resolve_table_info(binder, &path)?; let name = format!("{}.{}", schema.0, table_name); - let range = create_table.syntax().text_range(); - let selection_range = name_node.syntax().text_range(); + let full_range = create_table.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); + + let mut children = vec![]; + if let Some(table_arg_list) = create_table.table_arg_list() { + for arg in table_arg_list.args() { + if let ast::TableArg::Column(column) = arg + && let Some(column_symbol) = create_column_symbol(column) + { + children.push(column_symbol); + } + } + } Some(DocumentSymbol { name, detail: None, kind: DocumentSymbolKind::Table, - range, - selection_range, + full_range, + focus_range, + children, }) } @@ -74,22 +93,44 @@ fn create_function_symbol( let (schema, function_name) = resolve_function_info(binder, &path)?; let name = format!("{}.{}", schema.0, function_name); - let range = create_function.syntax().text_range(); - let selection_range = name_node.syntax().text_range(); + let full_range = create_function.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); Some(DocumentSymbol { name, detail: None, kind: DocumentSymbolKind::Function, - range, - selection_range, + full_range, + focus_range, + children: vec![], + }) +} + +fn create_column_symbol(column: ast::Column) -> Option { + let name_node = column.name()?; + let name = name_node.syntax().text().to_string(); + + let detail = column.ty().map(|t| t.syntax().text().to_string()); + + let full_range = column.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); + + Some(DocumentSymbol { + name, + detail, + kind: DocumentSymbolKind::Column, + full_range, + focus_range, + children: vec![], }) } #[cfg(test)] mod tests { use super::*; - use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle}; + use annotate_snippets::{ + AnnotationKind, Group, Level, Renderer, Snippet, renderer::DecorStyle, + }; use insta::assert_snapshot; fn symbols_not_found(sql: &str) { @@ -109,44 +150,110 @@ mod tests { panic!("No symbols found. If this is expected, use `symbols_not_found` instead.") } - let mut groups = vec![]; + let mut output = vec![]; for symbol in symbols { - let kind = match symbol.kind { - DocumentSymbolKind::Table => "table", - DocumentSymbolKind::Function => "function", - }; - let title = format!("{}: {}", kind, symbol.name); - let group = Level::INFO.primary_title(title).element( - Snippet::source(sql) - .fold(true) + let group = symbol_to_group(&symbol, sql); + output.push(group); + } + Renderer::plain() + .decor_style(DecorStyle::Unicode) + .render(&output) + .to_string() + } + + fn symbol_to_group<'a>(symbol: &DocumentSymbol, sql: &'a str) -> Group<'a> { + let kind = match symbol.kind { + DocumentSymbolKind::Table => "table", + DocumentSymbolKind::Function => "function", + DocumentSymbolKind::Column => "column", + }; + + let title = if let Some(detail) = &symbol.detail { + format!("{}: {} {}", kind, symbol.name, detail) + } else { + format!("{}: {}", kind, symbol.name) + }; + + let snippet = Snippet::source(sql) + .fold(true) + .annotation( + AnnotationKind::Primary + .span(symbol.focus_range.into()) + .label("focus range"), + ) + .annotation( + AnnotationKind::Context + .span(symbol.full_range.into()) + .label("full range"), + ); + + let mut group = Level::INFO.primary_title(title.clone()).element(snippet); + + if !symbol.children.is_empty() { + let child_labels: Vec = symbol + .children + .iter() + .map(|child| { + let kind = match child.kind { + DocumentSymbolKind::Column => "column", + _ => unreachable!("only columns can be children"), + }; + let detail = &child.detail.as_ref().unwrap(); + format!("{}: {} {}", kind, child.name, detail) + }) + .collect(); + + let mut children_snippet = Snippet::source(sql).fold(true); + + for (i, child) in symbol.children.iter().enumerate() { + children_snippet = children_snippet .annotation( - AnnotationKind::Primary - .span(symbol.selection_range.into()) - .label("name"), + AnnotationKind::Context + .span(child.full_range.into()) + .label(format!("full range for `{}`", child_labels[i].clone())), ) .annotation( - AnnotationKind::Context - .span(symbol.range.into()) - .label("select range"), - ), - ); - groups.push(group); + AnnotationKind::Primary + .span(child.focus_range.into()) + .label("focus range"), + ); + } + + group = group.element(children_snippet); } - let renderer = Renderer::plain().decor_style(DecorStyle::Unicode); - renderer.render(&groups).to_string() + group } #[test] fn create_table() { - assert_snapshot!(symbols("create table users (id int);"), @r" + assert_snapshot!(symbols(" +create table users ( + id int, + email citext +);"), @r" info: table: public.users ╭▸ - 1 │ create table users (id int); - │ ┬────────────┯━━━━───────── - │ │ │ - │ │ name - ╰╴select range + 2 │ create table users ( + │ │ ━━━━━ focus range + │ ┌─┘ + │ │ + 3 │ │ id int, + 4 │ │ email citext + 5 │ │ ); + │ └─┘ full range + │ + ⸬ + 3 │ id int, + │ ┯━──── + │ │ + │ full range for `column: id int` + │ focus range + 4 │ email citext + │ ┯━━━━─────── + │ │ + │ full range for `column: email citext` + ╰╴ focus range "); } @@ -160,8 +267,8 @@ mod tests { 1 │ create function hello() returns void as $$ select 1; $$ language sql; │ ┬───────────────┯━━━━─────────────────────────────────────────────── │ │ │ - │ │ name - ╰╴select range + │ │ focus range + ╰╴full range " ); } @@ -178,23 +285,37 @@ create function get_user(user_id int) returns void as $$ select 1; $$ language s 2 │ create table users (id int); │ ┬────────────┯━━━━───────── │ │ │ - │ │ name - │ select range + │ │ focus range + │ full range + │ + ⸬ + 2 │ create table users (id int); + │ ┯━──── + │ │ + │ full range for `column: id int` + │ focus range ╰╴ info: table: public.posts ╭▸ 3 │ create table posts (id int); │ ┬────────────┯━━━━───────── │ │ │ - │ │ name - ╰╴select range + │ │ focus range + │ full range + │ + ⸬ + 3 │ create table posts (id int); + │ ┯━──── + │ │ + │ full range for `column: id int` + ╰╴ focus range info: function: public.get_user ╭▸ 4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql; │ ┬───────────────┯━━━━━━━────────────────────────────────────────────────────────── │ │ │ - │ │ name - ╰╴select range + │ │ focus range + ╰╴full range "); } @@ -209,16 +330,23 @@ create function my_schema.hello() returns void as $$ select 1; $$ language sql; 2 │ create table public.users (id int); │ ┬───────────────────┯━━━━───────── │ │ │ - │ │ name - │ select range + │ │ focus range + │ full range + │ + ⸬ + 2 │ create table public.users (id int); + │ ┯━──── + │ │ + │ full range for `column: id int` + │ focus range ╰╴ info: function: my_schema.hello ╭▸ 3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql; │ ┬─────────────────────────┯━━━━─────────────────────────────────────────────── │ │ │ - │ │ name - ╰╴select range + │ │ focus range + ╰╴full range "); } diff --git a/crates/squawk_server/src/lib.rs b/crates/squawk_server/src/lib.rs index d6bf303e..81b5e3e2 100644 --- a/crates/squawk_server/src/lib.rs +++ b/crates/squawk_server/src/lib.rs @@ -316,27 +316,41 @@ fn handle_document_symbol( let symbols = document_symbols(&file); + fn convert_symbol( + sym: squawk_ide::document_symbols::DocumentSymbol, + line_index: &LineIndex, + ) -> DocumentSymbol { + let range = lsp_utils::range(line_index, sym.full_range); + let selection_range = lsp_utils::range(line_index, sym.focus_range); + + let children = sym + .children + .into_iter() + .map(|child| convert_symbol(child, line_index)) + .collect::>(); + + let children = (!children.is_empty()).then_some(children); + + DocumentSymbol { + name: sym.name, + detail: sym.detail, + kind: match sym.kind { + DocumentSymbolKind::Table => SymbolKind::STRUCT, + DocumentSymbolKind::Function => SymbolKind::FUNCTION, + DocumentSymbolKind::Column => SymbolKind::FIELD, + }, + tags: None, + range, + selection_range, + children, + #[allow(deprecated)] + deprecated: None, + } + } + let lsp_symbols: Vec = symbols .into_iter() - .map(|sym| { - let range = lsp_utils::range(&line_index, sym.range); - let selection_range = lsp_utils::range(&line_index, sym.selection_range); - - DocumentSymbol { - name: sym.name, - detail: sym.detail, - kind: match sym.kind { - DocumentSymbolKind::Table => SymbolKind::STRUCT, - DocumentSymbolKind::Function => SymbolKind::FUNCTION, - }, - tags: None, - range, - selection_range, - children: None, - #[allow(deprecated)] - deprecated: None, - } - }) + .map(|sym| convert_symbol(sym, &line_index)) .collect(); let resp = Response {