Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions crates/squawk_ide/src/document_symbols.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
use rowan::TextRange;
use squawk_syntax::ast::{self, AstNode};

use crate::binder;
use crate::resolve::{resolve_function_info, resolve_table_info};

pub enum DocumentSymbolKind {
Table,
Function,
}

pub struct DocumentSymbol {
pub name: String,
pub detail: Option<String>,
pub kind: DocumentSymbolKind,
pub range: TextRange,
pub selection_range: TextRange,
}

pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> {
let binder = binder::bind(file);
let mut symbols = vec![];

for stmt in file.stmts() {
match stmt {
ast::Stmt::CreateTable(create_table) => {
if let Some(symbol) = create_table_symbol(&binder, create_table) {
symbols.push(symbol);
}
}
ast::Stmt::CreateFunction(create_function) => {
if let Some(symbol) = create_function_symbol(&binder, create_function) {
symbols.push(symbol);
}
}
_ => {}
}
}

symbols
}

fn create_table_symbol(
binder: &binder::Binder,
create_table: ast::CreateTable,
) -> Option<DocumentSymbol> {
let path = create_table.path()?;
let segment = path.segment()?;
let name_node = segment.name()?;

let (schema, table_name) = resolve_table_info(binder, &path)?;
let name = format!("{}.{}", schema.0, table_name);

let range = create_table.syntax().text_range();
let selection_range = name_node.syntax().text_range();

Some(DocumentSymbol {
name,
detail: None,
kind: DocumentSymbolKind::Table,
range,
selection_range,
})
}

fn create_function_symbol(
binder: &binder::Binder,
create_function: ast::CreateFunction,
) -> Option<DocumentSymbol> {
let path = create_function.path()?;
let segment = path.segment()?;
let name_node = segment.name()?;

let (schema, function_name) = resolve_function_info(binder, &path)?;
let name = format!("{}.{}", schema.0, function_name);

let range = create_function.syntax().text_range();
let selection_range = name_node.syntax().text_range();

Some(DocumentSymbol {
name,
detail: None,
kind: DocumentSymbolKind::Function,
range,
selection_range,
})
}

#[cfg(test)]
mod tests {
use super::*;
use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
use insta::assert_snapshot;

fn symbols_not_found(sql: &str) {
let parse = ast::SourceFile::parse(sql);
let file = parse.tree();
let symbols = document_symbols(&file);
if !symbols.is_empty() {
panic!("Symbols found. If this is expected, use `symbols` instead.")
}
}

fn symbols(sql: &str) -> String {
let parse = ast::SourceFile::parse(sql);
let file = parse.tree();
let symbols = document_symbols(&file);
if symbols.is_empty() {
panic!("No symbols found. If this is expected, use `symbols_not_found` instead.")
}

let mut groups = vec![];
for symbol in symbols {
let kind = match symbol.kind {
DocumentSymbolKind::Table => "table",
DocumentSymbolKind::Function => "function",
};
let title = format!("{}: {}", kind, symbol.name);
let group = Level::INFO.primary_title(title).element(
Snippet::source(sql)
.fold(true)
.annotation(
AnnotationKind::Primary
.span(symbol.selection_range.into())
.label("name"),
)
.annotation(
AnnotationKind::Context
.span(symbol.range.into())
.label("select range"),
),
);
groups.push(group);
}

let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
renderer.render(&groups).to_string()
}

#[test]
fn create_table() {
assert_snapshot!(symbols("create table users (id int);"), @r"
info: table: public.users
╭▸
1 │ create table users (id int);
│ ┬────────────┯━━━━─────────
│ │ │
│ │ name
╰╴select range
");
}

#[test]
fn create_function() {
assert_snapshot!(
symbols("create function hello() returns void as $$ select 1; $$ language sql;"),
@r"
info: function: public.hello
╭▸
1 │ create function hello() returns void as $$ select 1; $$ language sql;
│ ┬───────────────┯━━━━───────────────────────────────────────────────
│ │ │
│ │ name
╰╴select range
"
);
}

#[test]
fn multiple_symbols() {
assert_snapshot!(symbols("
create table users (id int);
create table posts (id int);
create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
"), @r"
info: table: public.users
╭▸
2 │ create table users (id int);
│ ┬────────────┯━━━━─────────
│ │ │
│ │ name
│ select range
╰╴
info: table: public.posts
╭▸
3 │ create table posts (id int);
│ ┬────────────┯━━━━─────────
│ │ │
│ │ name
╰╴select range
info: function: public.get_user
╭▸
4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
│ ┬───────────────┯━━━━━━━──────────────────────────────────────────────────────────
│ │ │
│ │ name
╰╴select range
");
}

#[test]
fn qualified_names() {
assert_snapshot!(symbols("
create table public.users (id int);
create function my_schema.hello() returns void as $$ select 1; $$ language sql;
"), @r"
info: table: public.users
╭▸
2 │ create table public.users (id int);
│ ┬───────────────────┯━━━━─────────
│ │ │
│ │ name
│ select range
╰╴
info: function: my_schema.hello
╭▸
3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql;
│ ┬─────────────────────────┯━━━━───────────────────────────────────────────────
│ │ │
│ │ name
╰╴select range
");
}

#[test]
fn empty_file() {
symbols_not_found("")
}

#[test]
fn non_create_statements() {
symbols_not_found("select * from users;")
}
}
1 change: 1 addition & 0 deletions crates/squawk_ide/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod binder;
pub mod code_actions;
pub mod column_name;
pub mod document_symbols;
pub mod expand_selection;
pub mod find_references;
mod generated;
Expand Down
26 changes: 19 additions & 7 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,30 +779,42 @@ pub(crate) fn resolve_insert_table_columns(
}

pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
let table_name_str = extract_table_name_from_path(path)?;
resolve_symbol_info(binder, path, SymbolKind::Table)
}

pub(crate) fn resolve_function_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
resolve_symbol_info(binder, path, SymbolKind::Function)
}

fn resolve_symbol_info(
binder: &Binder,
path: &ast::Path,
kind: SymbolKind,
) -> Option<(Schema, String)> {
let name_str = extract_table_name_from_path(path)?;
let schema = extract_schema_from_path(path);

let table_name_normalized = Name::new(table_name_str.clone());
let symbols = binder.scopes[binder.root_scope()].get(&table_name_normalized)?;
let name_normalized = Name::new(name_str.clone());
let symbols = binder.scopes[binder.root_scope()].get(&name_normalized)?;

if let Some(schema_name) = schema {
let schema_normalized = Schema::new(schema_name);
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && symbol.schema == schema_normalized
symbol.kind == kind && symbol.schema == schema_normalized
})?;
let symbol = &binder.symbols[symbol_id];
return Some((symbol.schema.clone(), table_name_str));
return Some((symbol.schema.clone(), name_str));
} else {
let position = path.syntax().text_range().start();
let search_path = binder.search_path_at(position);
for search_schema in search_path {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && &symbol.schema == search_schema
symbol.kind == kind && &symbol.schema == search_schema
}) {
let symbol = &binder.symbols[symbol_id];
return Some((symbol.schema.clone(), table_name_str));
return Some((symbol.schema.clone(), name_str));
}
}
}
Expand Down
67 changes: 60 additions & 7 deletions crates/squawk_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@ use lsp_types::{
CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams,
CodeActionProviderCapability, CodeActionResponse, Command, Diagnostic,
DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverContents, HoverParams,
HoverProviderCapability, InitializeParams, InlayHint, InlayHintKind, InlayHintLabel,
InlayHintLabelPart, InlayHintParams, LanguageString, Location, MarkedString, OneOf,
PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams,
SelectionRangeProviderCapability, ServerCapabilities, TextDocumentSyncCapability,
DocumentSymbol, DocumentSymbolParams, GotoDefinitionParams, GotoDefinitionResponse, Hover,
HoverContents, HoverParams, HoverProviderCapability, InitializeParams, InlayHint,
InlayHintKind, InlayHintLabel, InlayHintLabelPart, InlayHintParams, LanguageString, Location,
MarkedString, OneOf, PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams,
SelectionRangeProviderCapability, ServerCapabilities, SymbolKind, TextDocumentSyncCapability,
TextDocumentSyncKind, Url, WorkDoneProgressOptions, WorkspaceEdit,
notification::{
DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _,
PublishDiagnostics,
},
request::{
CodeActionRequest, GotoDefinition, HoverRequest, InlayHintRequest, References, Request,
SelectionRangeRequest,
CodeActionRequest, DocumentSymbolRequest, GotoDefinition, HoverRequest, InlayHintRequest,
References, Request, SelectionRangeRequest,
},
};
use rowan::TextRange;
use squawk_ide::code_actions::code_actions;
use squawk_ide::document_symbols::{DocumentSymbolKind, document_symbols};
use squawk_ide::find_references::find_references;
use squawk_ide::goto_definition::goto_definition;
use squawk_ide::hover::hover;
Expand Down Expand Up @@ -67,6 +68,7 @@ pub fn run() -> Result<()> {
definition_provider: Some(OneOf::Left(true)),
hover_provider: Some(HoverProviderCapability::Simple(true)),
inlay_hint_provider: Some(OneOf::Left(true)),
document_symbol_provider: Some(OneOf::Left(true)),
..Default::default()
})
.unwrap();
Expand Down Expand Up @@ -119,6 +121,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
InlayHintRequest::METHOD => {
handle_inlay_hints(&connection, req, &documents)?;
}
DocumentSymbolRequest::METHOD => {
handle_document_symbol(&connection, req, &documents)?;
}
"squawk/syntaxTree" => {
handle_syntax_tree(&connection, req, &documents)?;
}
Expand Down Expand Up @@ -296,6 +301,54 @@ fn handle_inlay_hints(
Ok(())
}

fn handle_document_symbol(
connection: &Connection,
req: lsp_server::Request,
documents: &HashMap<Url, DocumentState>,
) -> Result<()> {
let params: DocumentSymbolParams = serde_json::from_value(req.params)?;
let uri = params.text_document.uri;

let content = documents.get(&uri).map_or("", |doc| &doc.content);
let parse = SourceFile::parse(content);
let file = parse.tree();
let line_index = LineIndex::new(content);

let symbols = document_symbols(&file);

let lsp_symbols: Vec<DocumentSymbol> = symbols
.into_iter()
.map(|sym| {
let range = lsp_utils::range(&line_index, sym.range);
let selection_range = lsp_utils::range(&line_index, sym.selection_range);

DocumentSymbol {
name: sym.name,
detail: sym.detail,
kind: match sym.kind {
DocumentSymbolKind::Table => SymbolKind::STRUCT,
DocumentSymbolKind::Function => SymbolKind::FUNCTION,
},
tags: None,
range,
selection_range,
children: None,
#[allow(deprecated)]
deprecated: None,
}
})
.collect();

let resp = Response {
id: req.id,
result: Some(serde_json::to_value(&lsp_symbols).unwrap()),
error: None,
};

connection.sender.send(Message::Response(resp))?;
Ok(())
}

fn handle_selection_range(
connection: &Connection,
req: lsp_server::Request,
Expand Down
Loading