Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3077,4 +3077,179 @@ delete from users where id in (select id$0 from old_data);
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_table() {
assert_snapshot!(goto("
create table users(id int, email text);
update users$0 set email = 'new@example.com';
"), @r"
╭▸
2 │ create table users(id int, email text);
│ ───── 2. destination
3 │ update users set email = 'new@example.com';
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_table_with_schema() {
assert_snapshot!(goto("
create table public.users(id int, email text);
update public.users$0 set email = 'new@example.com';
"), @r"
╭▸
2 │ create table public.users(id int, email text);
│ ───── 2. destination
3 │ update public.users set email = 'new@example.com';
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_table_with_search_path() {
assert_snapshot!(goto("
set search_path to foo;
create table foo.users(id int, email text);
update users$0 set email = 'new@example.com';
"), @r"
╭▸
3 │ create table foo.users(id int, email text);
│ ───── 2. destination
4 │ update users set email = 'new@example.com';
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_where_column() {
assert_snapshot!(goto("
create table users(id int, email text);
update users set email = 'new@example.com' where id$0 = 1;
"), @r"
╭▸
2 │ create table users(id int, email text);
│ ── 2. destination
3 │ update users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_where_column_with_schema() {
assert_snapshot!(goto("
create table public.users(id int, email text);
update public.users set email = 'new@example.com' where id$0 = 1;
"), @r"
╭▸
2 │ create table public.users(id int, email text);
│ ── 2. destination
3 │ update public.users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_where_column_with_search_path() {
assert_snapshot!(goto("
set search_path to foo;
create table foo.users(id int, email text);
update users set email = 'new@example.com' where id$0 = 1;
"), @r"
╭▸
3 │ create table foo.users(id int, email text);
│ ── 2. destination
4 │ update users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_set_column() {
assert_snapshot!(goto("
create table users(id int, email text);
update users set email$0 = 'new@example.com' where id = 1;
"), @r"
╭▸
2 │ create table users(id int, email text);
│ ───── 2. destination
3 │ update users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_set_column_with_schema() {
assert_snapshot!(goto("
create table public.users(id int, email text);
update public.users set email$0 = 'new@example.com' where id = 1;
"), @r"
╭▸
2 │ create table public.users(id int, email text);
│ ───── 2. destination
3 │ update public.users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_set_column_with_search_path() {
assert_snapshot!(goto("
set search_path to foo;
create table foo.users(id int, email text);
update users set email$0 = 'new@example.com' where id = 1;
"), @r"
╭▸
3 │ create table foo.users(id int, email text);
│ ───── 2. destination
4 │ update users set email = 'new@example.com' where id = 1;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_from_table() {
assert_snapshot!(goto("
create table users(id int, email text);
create table messages(id int, user_id int, email text);
update users set email = messages.email from messages$0 where users.id = messages.user_id;
"), @r"
╭▸
3 │ create table messages(id int, user_id int, email text);
│ ──────── 2. destination
4 │ update users set email = messages.email from messages where users.id = messages.user_id;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_from_table_with_schema() {
assert_snapshot!(goto("
create table users(id int, email text);
create table public.messages(id int, user_id int, email text);
update users set email = messages.email from public.messages$0 where users.id = messages.user_id;
"), @r"
╭▸
3 │ create table public.messages(id int, user_id int, email text);
│ ──────── 2. destination
4 │ update users set email = messages.email from public.messages where users.id = messages.user_id;
╰╴ ─ 1. source
");
}

#[test]
fn goto_update_from_table_with_search_path() {
assert_snapshot!(goto("
set search_path to foo;
create table users(id int, email text);
create table foo.messages(id int, user_id int, email text);
update users set email = messages.email from messages$0 where users.id = messages.user_id;
"), @r"
╭▸
4 │ create table foo.messages(id int, user_id int, email text);
│ ──────── 2. destination
5 │ update users set email = messages.email from messages where users.id = messages.user_id;
╰╴ ─ 1. source
");
}
}
110 changes: 109 additions & 1 deletion crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ enum NameRefContext {
InsertColumn,
DeleteTable,
DeleteWhereColumn,
UpdateTable,
UpdateWhereColumn,
UpdateSetColumn,
UpdateFromTable,
SchemaQualifier,
}

Expand All @@ -42,7 +46,8 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
| NameRefContext::Table
| NameRefContext::CreateIndex
| NameRefContext::InsertTable
| NameRefContext::DeleteTable => {
| NameRefContext::DeleteTable
| NameRefContext::UpdateTable => {
let path = find_containing_path(name_ref)?;
let table_name = extract_table_name(&path)?;
let schema = extract_schema_name(&path);
Expand Down Expand Up @@ -201,6 +206,29 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
NameRefContext::SelectQualifiedColumn => resolve_select_qualified_column(binder, name_ref),
NameRefContext::InsertColumn => resolve_insert_column(binder, name_ref),
NameRefContext::DeleteWhereColumn => resolve_delete_where_column(binder, name_ref),
NameRefContext::UpdateWhereColumn => resolve_update_where_column(binder, name_ref),
NameRefContext::UpdateSetColumn => resolve_update_set_column(binder, name_ref),
NameRefContext::UpdateFromTable => {
let table_name = Name::from_node(name_ref);
let schema = if let Some(parent) = name_ref.syntax().parent()
&& let Some(field_expr) = ast::FieldExpr::cast(parent)
&& let Some(base) = field_expr.base()
&& let Some(schema_name_ref) = ast::NameRef::cast(base.syntax().clone())
{
Some(Schema(Name::from_node(&schema_name_ref)))
} else {
None
};

if schema.is_none()
&& let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name)
{
return Some(cte_ptr);
}

let position = name_ref.syntax().text_range().start();
resolve_table(binder, &table_name, &schema, position)
}
}
}

Expand All @@ -211,6 +239,7 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
let mut in_column_list = false;
let mut in_where_clause = false;
let mut in_from_clause = false;
let mut in_set_clause = false;

// TODO: can we combine this if and the one that follows?
if let Some(parent) = name_ref.syntax().parent()
Expand Down Expand Up @@ -368,12 +397,27 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
if ast::WhereClause::can_cast(ancestor.kind()) {
in_where_clause = true;
}
if ast::SetClause::can_cast(ancestor.kind()) {
in_set_clause = true;
}
if ast::Delete::can_cast(ancestor.kind()) {
if in_where_clause {
return Some(NameRefContext::DeleteWhereColumn);
}
return Some(NameRefContext::DeleteTable);
}
if ast::Update::can_cast(ancestor.kind()) {
if in_where_clause {
return Some(NameRefContext::UpdateWhereColumn);
}
if in_set_clause {
return Some(NameRefContext::UpdateSetColumn);
}
if in_from_clause {
return Some(NameRefContext::UpdateFromTable);
}
return Some(NameRefContext::UpdateTable);
}
}

None
Expand Down Expand Up @@ -930,6 +974,70 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti
None
}

fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
let column_name = Name::from_node(name_ref);

let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?;
let relation_name = update.relation_name()?;
let path = relation_name.path()?;

let table_name = extract_table_name(&path)?;
let schema = extract_schema_name(&path);
let position = name_ref.syntax().text_range().start();

let table_ptr = resolve_table(binder, &table_name, &schema, position)?;

let root = &name_ref.syntax().ancestors().last()?;
let table_name_node = table_ptr.to_node(root);

let create_table = table_name_node
.ancestors()
.find_map(ast::CreateTable::cast)?;

for arg in create_table.table_arg_list()?.args() {
if let ast::TableArg::Column(column) = arg
&& let Some(col_name) = column.name()
&& Name::from_node(&col_name) == column_name
{
return Some(SyntaxNodePtr::new(col_name.syntax()));
}
}

None
}

fn resolve_update_set_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
let column_name = Name::from_node(name_ref);

let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?;
let relation_name = update.relation_name()?;
let path = relation_name.path()?;

let table_name = extract_table_name(&path)?;
let schema = extract_schema_name(&path);
let position = name_ref.syntax().text_range().start();

let table_ptr = resolve_table(binder, &table_name, &schema, position)?;

let root = &name_ref.syntax().ancestors().last()?;
let table_name_node = table_ptr.to_node(root);

let create_table = table_name_node
.ancestors()
.find_map(ast::CreateTable::cast)?;

for arg in create_table.table_arg_list()?.args() {
if let ast::TableArg::Column(column) = arg
&& let Some(col_name) = column.name()
&& Name::from_node(&col_name) == column_name
{
return Some(SyntaxNodePtr::new(col_name.syntax()));
}
}

None
}

fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
let column_name = Name::from_node(name_ref);

Expand Down
Loading