diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index fec2e2d9..0bb190a9 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -1,7 +1,9 @@ +use std::{ + cmp, + collections::{HashMap, HashSet}, +}; mod policy_parser; -use std::collections::{HashMap, HashSet}; - use pgt_schema_cache::SchemaCache; use pgt_text_size::TextRange; use pgt_treesitter_queries::{ @@ -15,7 +17,7 @@ use crate::{ sanitization::SanitizedCompletionParams, }; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum WrappingClause<'a> { Select, Where, @@ -25,11 +27,15 @@ pub enum WrappingClause<'a> { }, Update, Delete, + ColumnDefinitions, + Insert, + AlterTable, + DropTable, PolicyName, ToRoleAssignment, } -#[derive(PartialEq, Eq, Hash, Debug)] +#[derive(PartialEq, Eq, Hash, Debug, Clone)] pub(crate) struct MentionedColumn { pub(crate) column: String, pub(crate) alias: Option, @@ -48,6 +54,7 @@ pub enum WrappingNode { Relation, BinaryExpression, Assignment, + List, } #[derive(Debug)] @@ -97,6 +104,7 @@ impl TryFrom<&str> for WrappingNode { "relation" => Ok(Self::Relation), "assignment" => Ok(Self::Assignment), "binary_expression" => Ok(Self::BinaryExpression), + "list" => Ok(Self::List), _ => { let message = format!("Unimplemented Relation: {}", value); @@ -118,6 +126,7 @@ impl TryFrom for WrappingNode { } } +#[derive(Debug)] pub(crate) struct CompletionContext<'a> { pub node_under_cursor: Option>, @@ -152,9 +161,6 @@ pub(crate) struct CompletionContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option, - /// Some incomplete statements can't be correctly parsed by TreeSitter. 
- pub is_in_error_node: bool, - pub mentioned_relations: HashMap, HashSet>, pub mentioned_table_aliases: HashMap, pub mentioned_columns: HashMap>, HashSet>, @@ -176,7 +182,6 @@ impl<'a> CompletionContext<'a> { mentioned_relations: HashMap::new(), mentioned_table_aliases: HashMap::new(), mentioned_columns: HashMap::new(), - is_in_error_node: false, }; // policy handling is important to Supabase, but they are a PostgreSQL specific extension, @@ -231,6 +236,8 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); executor.add_query_results::(); executor.add_query_results::(); + executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -238,37 +245,61 @@ impl<'a> CompletionContext<'a> { let schema_name = r.get_schema(sql); let table_name = r.get_table(sql); - if let Some(c) = self.mentioned_relations.get_mut(&schema_name) { - c.insert(table_name); - } else { - let mut new = HashSet::new(); - new.insert(table_name); - self.mentioned_relations.insert(schema_name, new); - } + self.mentioned_relations + .entry(schema_name) + .and_modify(|s| { + s.insert(table_name.clone()); + }) + .or_insert(HashSet::from([table_name])); } + QueryResult::TableAliases(table_alias_match) => { self.mentioned_table_aliases.insert( table_alias_match.get_alias(sql), table_alias_match.get_table(sql), ); } + QueryResult::SelectClauseColumns(c) => { let mentioned = MentionedColumn { column: c.get_column(sql), alias: c.get_alias(sql), }; - if let Some(cols) = self - .mentioned_columns - .get_mut(&Some(WrappingClause::Select)) - { - cols.insert(mentioned); - } else { - let mut new = HashSet::new(); - new.insert(mentioned); - self.mentioned_columns - .insert(Some(WrappingClause::Select), new); - } + self.mentioned_columns + .entry(Some(WrappingClause::Select)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); + } + + QueryResult::WhereClauseColumns(c) => 
{ + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: c.get_alias(sql), + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Where)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); + } + + QueryResult::InsertClauseColumns(c) => { + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); } _ => {} }; @@ -317,10 +348,20 @@ impl<'a> CompletionContext<'a> { * `select * from use {}` becomes `select * from use{}`. */ let current_node = cursor.node(); - while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { - self.position -= 1; + + let mut chars = self.text.chars(); + + if chars + .nth(self.position) + .is_some_and(|c| !c.is_ascii_whitespace() && !&[';', ')'].contains(&c)) + { + self.position = cmp::min(self.position + 1, self.text.len()); + } else { + self.position = cmp::min(self.position, self.text.len()); } + cursor.goto_first_child_for_byte(self.position); + self.gather_context_from_node(cursor, current_node); } @@ -334,8 +375,9 @@ impl<'a> CompletionContext<'a> { let parent_node_kind = parent_node.kind(); let current_node_kind = current_node.kind(); - // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == parent_node_kind { + // prevent infinite recursion – this can happen with ERROR nodes + if current_node_kind == parent_node_kind && ["ERROR", "program"].contains(&parent_node_kind) + { self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } @@ -352,25 +394,17 @@ impl<'a> CompletionContext<'a> { } // try to gather context from the siblings if we're within an error node. 
- if self.is_in_error_node { - let mut next_sibling = current_node.next_named_sibling(); - while let Some(n) = next_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - next_sibling = n.next_named_sibling(); - } + if parent_node_kind == "ERROR" { + if let Some(clause_type) = self.get_wrapping_clause_from_error_node_child(current_node) + { + self.wrapping_clause_type = Some(clause_type); } - let mut prev_sibling = current_node.prev_named_sibling(); - while let Some(n) = prev_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - prev_sibling = n.prev_named_sibling(); - } + if let Some(wrapping_node) = self.get_wrapping_node_from_error_node_child(current_node) + { + self.wrapping_node_kind = Some(wrapping_node) } + + self.get_info_from_error_node_child(current_node); } match current_node_kind { @@ -389,7 +423,8 @@ impl<'a> CompletionContext<'a> { } } - "where" | "update" | "select" | "delete" | "from" | "join" => { + "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" + | "drop_table" | "alter_table" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -398,8 +433,13 @@ impl<'a> CompletionContext<'a> { self.wrapping_node_kind = current_node_kind.try_into().ok(); } - "ERROR" => { - self.is_in_error_node = true; + "list" => { + if current_node + .prev_sibling() + .is_none_or(|n| n.kind() != "keyword_values") + { + self.wrapping_node_kind = current_node_kind.try_into().ok(); + } } _ => {} @@ -415,31 +455,165 @@ impl<'a> CompletionContext<'a> { self.gather_context_from_node(cursor, current_node); } - fn get_wrapping_clause_from_keyword_node( + fn get_first_sibling(&self, node: tree_sitter::Node<'a>) -> tree_sitter::Node<'a> { + let mut first_sibling = node; + while let Some(n) = 
first_sibling.prev_sibling() { + first_sibling = n; + } + first_sibling + } + + fn get_wrapping_node_from_error_node_child( + &self, + node: tree_sitter::Node<'a>, + ) -> Option { + self.wrapping_clause_type + .as_ref() + .and_then(|clause| match clause { + WrappingClause::Insert => { + let mut first_sib = self.get_first_sibling(node); + + let mut after_opening_bracket = false; + let mut before_closing_bracket = false; + + while let Some(next_sib) = first_sib.next_sibling() { + if next_sib.kind() == "(" + && next_sib.end_position() <= node.start_position() + { + after_opening_bracket = true; + } + + if next_sib.kind() == ")" + && next_sib.start_position() >= node.end_position() + { + before_closing_bracket = true; + } + + first_sib = next_sib; + } + + if after_opening_bracket && before_closing_bracket { + Some(WrappingNode::List) + } else { + None + } + } + _ => None, + }) + } + + fn get_wrapping_clause_from_error_node_child( &self, node: tree_sitter::Node<'a>, ) -> Option> { - if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt.as_str() { - "where" => return Some(WrappingClause::Where), - "update" => return Some(WrappingClause::Update), - "select" => return Some(WrappingClause::Select), - "delete" => return Some(WrappingClause::Delete), - "from" => return Some(WrappingClause::From), - "join" => { - // TODO: not sure if we can infer it here. 
- return Some(WrappingClause::Join { on_node: None }); + let clause_combinations: Vec<(WrappingClause, &[&'static str])> = vec![ + (WrappingClause::Where, &["where"]), + (WrappingClause::Update, &["update"]), + (WrappingClause::Select, &["select"]), + (WrappingClause::Delete, &["delete"]), + (WrappingClause::Insert, &["insert", "into"]), + (WrappingClause::From, &["from"]), + (WrappingClause::Join { on_node: None }, &["join"]), + (WrappingClause::AlterTable, &["alter", "table"]), + ( + WrappingClause::AlterTable, + &["alter", "table", "if", "exists"], + ), + (WrappingClause::DropTable, &["drop", "table"]), + ( + WrappingClause::DropTable, + &["drop", "table", "if", "exists"], + ), + ]; + + let first_sibling = self.get_first_sibling(node); + + /* + * For each clause, we'll iterate from first_sibling to the next ones, + * either until the end or until we land on the node under the cursor. + * We'll score the `WrappingClause` by how many tokens it matches in order. + */ + let mut clauses_with_score: Vec<(WrappingClause, usize)> = clause_combinations + .into_iter() + .map(|(clause, tokens)| { + let mut idx = 0; + + let mut sibling = Some(first_sibling); + while let Some(sib) = sibling { + if sib.end_byte() >= node.end_byte() || idx >= tokens.len() { + break; + } + + if let Some(sibling_content) = + self.get_ts_node_content(&sib).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) + { + if sibling_content == tokens[idx] { + idx += 1; + } + } else { + break; } - _ => {} + + sibling = sib.next_sibling(); } - }; - } - None + (clause, idx) + }) + .collect(); + + clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a)); + clauses_with_score + .iter() + .find(|(_, score)| *score > 0) + .map(|c| c.0.clone()) + } + + fn get_info_from_error_node_child(&mut self, node: tree_sitter::Node<'a>) { + let mut first_sibling = self.get_first_sibling(node); + + if let Some(clause) = self.wrapping_clause_type.as_ref() { 
+ if clause == &WrappingClause::Insert { + while let Some(sib) = first_sibling.next_sibling() { + match sib.kind() { + "object_reference" => { + if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { + let mut iter = txt.split('.').rev(); + let table = iter.next().unwrap().to_string(); + let schema = iter.next().map(|s| s.to_string()); + self.mentioned_relations + .entry(schema) + .and_modify(|s| { + s.insert(table.clone()); + }) + .or_insert(HashSet::from([table])); + } + } + "column" => { + if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { + let entry = MentionedColumn { + column: txt, + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(entry.clone()); + }) + .or_insert(HashSet::from([entry])); + } + } + + _ => {} + } + + first_sibling = sib; + } + } + } } fn get_wrapping_clause_from_current_node( @@ -453,6 +627,10 @@ impl<'a> CompletionContext<'a> { "select" => Some(WrappingClause::Select), "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), + "drop_table" => Some(WrappingClause::DropTable), + "alter_table" => Some(WrappingClause::AlterTable), + "column_definitions" => Some(WrappingClause::ColumnDefinitions), + "insert" => Some(WrappingClause::Insert), "join" => { // sadly, we need to manually iterate over the children – // `node.child_by_field_id(..)` does not work as expected @@ -469,6 +647,38 @@ impl<'a> CompletionContext<'a> { _ => None, } } + + pub(crate) fn parent_matches_one_of_kind(&self, kinds: &[&'static str]) -> bool { + self.node_under_cursor + .as_ref() + .is_some_and(|under_cursor| match under_cursor { + NodeUnderCursor::TsNode(node) => node + .parent() + .is_some_and(|parent| kinds.contains(&parent.kind())), + + NodeUnderCursor::CustomNode { .. 
} => false, + }) + } + pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { + self.node_under_cursor.as_ref().is_some_and(|under_cursor| { + match under_cursor { + NodeUnderCursor::TsNode(node) => { + let mut current = *node; + + // move up to the parent until we're at top OR we have a prev sibling + while current.prev_sibling().is_none() && current.parent().is_some() { + current = current.parent().unwrap(); + } + + current + .prev_sibling() + .is_some_and(|sib| kinds.contains(&sib.kind())) + } + + NodeUnderCursor::CustomNode { .. } => false, + } + }) + } } #[cfg(test)] diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 8109ba83..a040bab1 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -23,7 +23,12 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio }; // autocomplete with the alias in a join clause if we find one - if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) { + if matches!( + ctx.wrapping_clause_type, + Some(WrappingClause::Join { .. 
}) + | Some(WrappingClause::Where) + | Some(WrappingClause::Select) + ) { item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str()) .and_then(|alias| { get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str()) @@ -36,11 +41,13 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio #[cfg(test)] mod tests { + use std::vec; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ - CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps, - get_test_params, + CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, + assert_no_complete_results, get_test_deps, get_test_params, }, }; @@ -573,4 +580,151 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_columns_in_insert_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text + ); + + create table others ( + id serial primary key, + a text, + b text + ); + "#; + + // We should prefer the instrument columns, even though they + // are lower in the alphabet + + assert_complete_results( + format!("insert into instruments ({})", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("insert into instruments (id, {})", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("insert into instruments (id, {}, name)", CURSOR_POS).as_str(), + vec![CompletionAssertion::Label("z".to_string())], + setup, + ) + .await; + + // works with completed statement + assert_complete_results( + format!( + "insert into instruments (name, {}) values ('my_bass');", + CURSOR_POS + ) + .as_str(), 
+ vec![ + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + + // no completions in the values list! + assert_no_complete_results( + format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), + setup, + ) + .await; + } + + #[tokio::test] + async fn suggests_columns_in_where_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text, + created_at timestamp with time zone default now() + ); + + create table others ( + a text, + b text, + c text + ); + "#; + + assert_complete_results( + format!("select name from instruments where {} ", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!( + "select name from instruments where z = 'something' and created_at > {}", + CURSOR_POS + ) + .as_str(), + // simply do not complete columns + schemas; functions etc. 
are ok + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Column), + CompletionAssertion::KindNotExists(CompletionItemKind::Schema), + ], + setup, + ) + .await; + + // prefers not mentioned columns + assert_complete_results( + format!( + "select name from instruments where id = 'something' and {}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + // // uses aliases + assert_complete_results( + format!( + "select name from instruments i join others o on i.z = o.a where i.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 57195da7..96d327de 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -310,4 +310,123 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_tables_in_alter_and_drop_statements() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + assert_complete_results( + format!("alter table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + 
setup, + ) + .await; + + assert_complete_results( + format!("alter table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } + + #[tokio::test] + async fn suggests_tables_in_insert_into() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + "#; + + assert_complete_results( + format!("insert into {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("insert into 
auth.{}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "users".into(), + CompletionItemKind::Table, + )], + setup, + ) + .await; + + // works with complete statement. + assert_complete_results( + format!( + "insert into {} (name, email) values ('jules', 'a@b.com');", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 3b148336..5323e2bc 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause}; +use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -24,6 +24,10 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { + if ctx.wrapping_node_kind.is_none() && ctx.wrapping_clause_type.is_none() { + return None; + } + let current_node_kind = ctx .node_under_cursor .as_ref() @@ -65,55 +69,109 @@ impl CompletionFilter<'_> { } fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { - let clause = ctx.wrapping_clause_type.as_ref(); + ctx.wrapping_clause_type + .as_ref() + .map(|clause| { + match self.data { + CompletionRelevanceData::Table(_) => match clause { + WrappingClause::Select + | WrappingClause::Where + | WrappingClause::ColumnDefinitions => false, - let in_clause = |compare: WrappingClause| clause.is_some_and(|c| c == &compare); + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_none_or(|n| n != &WrappingNode::List) + && (ctx.before_cursor_matches_kind(&["keyword_into"]) 
+ || (ctx.before_cursor_matches_kind(&["."]) + && ctx.parent_matches_one_of_kind(&["object_reference"]))) + } - match self.data { - CompletionRelevanceData::Table(_) => { - if in_clause(WrappingClause::Select) - || in_clause(WrappingClause::Where) - || in_clause(WrappingClause::PolicyName) - { - return None; - }; - } - CompletionRelevanceData::Column(_) => { - if in_clause(WrappingClause::From) || in_clause(WrappingClause::PolicyName) { - return None; - } + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), - // We can complete columns in JOIN cluases, but only if we are after the - // ON node in the "ON u.id = posts.user_id" part. - let in_join_clause_before_on_node = clause.is_some_and(|c| match c { - // we are in a JOIN, but definitely not after an ON - WrappingClause::Join { on_node: None } => true, + _ => true, + }, - WrappingClause::Join { on_node: Some(on) } => ctx - .node_under_cursor - .as_ref() - .is_some_and(|n| n.end_byte() < on.start_byte()), + CompletionRelevanceData::Column(_) => { + match clause { + WrappingClause::From + | WrappingClause::ColumnDefinitions + | WrappingClause::AlterTable + | WrappingClause::DropTable => false, - _ => false, - }); + // We can complete columns in JOIN clauses, but only if we are after the + // ON node in the "ON u.id = posts.user_id" part. 
+ WrappingClause::Join { on_node: Some(on) } => ctx + .node_under_cursor + .as_ref() + .is_some_and(|cn| cn.start_byte() >= on.end_byte()), - if in_join_clause_before_on_node { - return None; - } - } - CompletionRelevanceData::Policy(_) => { - if clause.is_none_or(|c| c != &WrappingClause::PolicyName) { - return None; - } - } - _ => { - if in_clause(WrappingClause::PolicyName) { - return None; - } - } - } + // we are in a JOIN, but definitely not after an ON + WrappingClause::Join { on_node: None } => false, - Some(()) + WrappingClause::Insert => ctx + .wrapping_node_kind + .as_ref() + .is_some_and(|n| n == &WrappingNode::List), + + // only autocomplete left side of binary expression + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + || (ctx.before_cursor_matches_kind(&["."]) + && ctx.parent_matches_one_of_kind(&["field"])) + } + + _ => true, + } + } + + CompletionRelevanceData::Function(_) => matches!( + clause, + WrappingClause::From + | WrappingClause::Select + | WrappingClause::Where + | WrappingClause::Join { .. } + ), + + CompletionRelevanceData::Schema(_) => match clause { + WrappingClause::Select + | WrappingClause::From + | WrappingClause::Join { .. 
} + | WrappingClause::Update + | WrappingClause::Delete => true, + + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + } + + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), + + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_none_or(|n| n != &WrappingNode::List) + && ctx.before_cursor_matches_kind(&["keyword_into"]) + } + + _ => false, + }, + + CompletionRelevanceData::Policy(_) => { + matches!(clause, WrappingClause::PolicyName) + } + } + }) + .and_then(|is_ok| if is_ok { Some(()) } else { None }) } fn check_invocation(&self, ctx: &CompletionContext) -> Option<()> { @@ -188,4 +246,15 @@ mod tests { ) .await; } + + #[tokio::test] + async fn completion_after_create_table() { + assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), "").await; + } + + #[tokio::test] + async fn completion_in_column_definitions() { + let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); + assert_no_complete_results(query.as_str(), "").await; + } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 6aa75a16..40dea7e6 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -53,6 +53,7 @@ where || cursor_prepared_to_write_token_after_last_node(&params.text, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(&params.text, params.position) + || cursor_between_parentheses(&params.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -192,24 +193,81 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool return false; } - // not okay to be on the semi. 
- if byte == leaf_node.start_byte() { - return false; - } - leaf_node .prev_named_sibling() .map(|n| n.end_byte() < byte) .unwrap_or(false) } +fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + + let mut level = 0; + let mut tracking_open_idx = None; + + let mut matching_open_idx = None; + let mut matching_close_idx = None; + + for (idx, char) in sql.chars().enumerate() { + if char == '(' { + tracking_open_idx = Some(idx); + level += 1; + } + + if char == ')' { + level -= 1; + + if tracking_open_idx.is_some_and(|it| it < position) && idx >= position { + matching_open_idx = tracking_open_idx; + matching_close_idx = Some(idx) + } + } + } + + // invalid statement + if level != 0 { + return false; + } + + // early check: '(|)' + // however, we want to check this after the level nesting. + let mut chars = sql.chars(); + if chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')') { + return true; + } + + // not *within* parentheses + if matching_open_idx.is_none() || matching_close_idx.is_none() { + return false; + } + + // use string indexing, because we can't `.rev()` after `.take()` + let before = sql[..position] + .to_string() + .chars() + .rev() + .find(|c| !c.is_whitespace()) + .unwrap_or_default(); + + let after = sql + .chars() + .skip(position) + .find(|c| !c.is_whitespace()) + .unwrap_or_default(); + + let before_matches = before == ',' || before == '('; + let after_matches = after == ',' || after == ')'; + + before_matches && after_matches +} + #[cfg(test)] mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot, - cursor_prepared_to_write_token_after_last_node, + cursor_before_semicolon, cursor_between_parentheses, cursor_inbetween_nodes, + cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, }; #[test] @@ -292,18 +350,67 @@ mod tests { // select * from| ; <-- still touches 
the from assert!(!cursor_before_semicolon(&tree, TextSize::new(13))); - // not okay to be ON the semi. - // select * from |; - assert!(!cursor_before_semicolon(&tree, TextSize::new(18))); - // anything is fine here - // select * from | ; - // select * from | ; - // select * from | ; - // select * from |; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from |; assert!(cursor_before_semicolon(&tree, TextSize::new(14))); assert!(cursor_before_semicolon(&tree, TextSize::new(15))); assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); + assert!(cursor_before_semicolon(&tree, TextSize::new(18))); + } + + #[test] + fn between_parentheses() { + let input = "insert into instruments ()"; + + // insert into (|) <- right in the parentheses + assert!(cursor_between_parentheses(input, TextSize::new(25))); + + // insert into ()| <- too late + assert!(!cursor_between_parentheses(input, TextSize::new(26))); + + // insert into |() <- too early + assert!(!cursor_between_parentheses(input, TextSize::new(24))); + + let input = "insert into instruments (name, id, )"; + + // insert into instruments (name, id, |) <-- we should sanitize the next column + assert!(cursor_between_parentheses(input, TextSize::new(35))); + + // insert into instruments (name, id|, ) <-- we are still on the previous token. + assert!(!cursor_between_parentheses(input, TextSize::new(33))); + + let input = "insert into instruments (name, , id)"; + + // insert into instruments (name, |, id) <-- we can sanitize! + assert!(cursor_between_parentheses(input, TextSize::new(31))); + + // insert into instruments (name, ,| id) <-- we are already on the next token + assert!(!cursor_between_parentheses(input, TextSize::new(32))); + + let input = "insert into instruments (, name, id)"; + + // insert into instruments (|, name, id) <-- we can sanitize! 
+ assert!(cursor_between_parentheses(input, TextSize::new(25))); + + // insert into instruments (,| name, id) <-- already on next token + assert!(!cursor_between_parentheses(input, TextSize::new(26))); + + // bails on invalidly nested statements + assert!(!cursor_between_parentheses( + "insert into (instruments ()", + TextSize::new(26) + )); + + // can find its position in nested statements + // "insert into instruments (name) values (a_function(name, |))", + assert!(cursor_between_parentheses( + "insert into instruments (name) values (a_function(name, ))", + TextSize::new(56) + )); } } diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index b3e35b69..acfc60ed 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -37,7 +37,7 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa // The request is used to get more information about a simple CompletionItem. resolve_provider: None, - trigger_characters: Some(vec![".".to_owned(), " ".to_owned()]), + trigger_characters: Some(vec![".".to_owned(), " ".to_owned(), "(".to_owned()]), // No character will lead to automatically inserting the selected completion-item all_commit_characters: None, diff --git a/crates/pgt_treesitter_queries/src/queries/insert_columns.rs b/crates/pgt_treesitter_queries/src/queries/insert_columns.rs new file mode 100644 index 00000000..3e88d998 --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/insert_columns.rs @@ -0,0 +1,150 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (insert + (object_reference) + (list + "("? + (column) @column + ","? + ")"? 
+ ) + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct InsertColumnMatch<'a> { + pub(crate) column: tree_sitter::Node<'a>, +} + +impl InsertColumnMatch<'_> { + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a InsertColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::InsertClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for InsertColumnMatch<'a> { + type Ref = &'a InsertColumnMatch<'a>; +} + +impl<'a> Query<'a> for InsertColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::InsertClauseColumns(InsertColumnMatch { + column: capture, + })); + } + } + + to_return + } +} +#[cfg(test)] +mod tests { + use super::InsertColumnMatch; + use crate::TreeSitterQueriesExecutor; + + #[test] + fn finds_all_insert_columns() { + let sql = r#"insert into users (id, email, name) values (1, 'a@b.com', 'Alice');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + let columns: Vec = results.iter().map(|c| 
c.get_column(sql)).collect(); + + assert_eq!(columns, vec!["id", "email", "name"]); + } + + #[test] + fn finds_insert_columns_with_whitespace_and_commas() { + let sql = r#" + insert into users ( + id, + email, + name + ) values (1, 'a@b.com', 'Alice'); + "#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + let columns: Vec = results.iter().map(|c| c.get_column(sql)).collect(); + + assert_eq!(columns, vec!["id", "email", "name"]); + } + + #[test] + fn returns_empty_for_insert_without_columns() { + let sql = r#"insert into users values (1, 'a@b.com', 'Alice');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert!(results.is_empty()); + } +} diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index aec6ce1a..b9f39aed 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,12 +1,16 @@ +mod insert_columns; mod parameters; mod relations; mod select_columns; mod table_aliases; +mod where_columns; +pub use insert_columns::*; pub use parameters::*; pub use relations::*; pub use select_columns::*; pub use table_aliases::*; +pub use where_columns::*; #[derive(Debug)] pub enum QueryResult<'a> { @@ -14,6 +18,8 @@ pub enum QueryResult<'a> { 
Parameter(ParameterMatch<'a>), TableAliases(TableAliasMatch<'a>), SelectClauseColumns(SelectColumnMatch<'a>), + InsertClauseColumns(InsertColumnMatch<'a>), + WhereClauseColumns(WhereColumnMatch<'a>), } impl QueryResult<'_> { @@ -50,6 +56,21 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + Self::WhereClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } + Self::InsertClauseColumns(cm) => { + let start = cm.column.start_position(); + let end = cm.column.end_position(); + start >= range.start_point && end <= range.end_point + } } } } diff --git a/crates/pgt_treesitter_queries/src/queries/relations.rs b/crates/pgt_treesitter_queries/src/queries/relations.rs index f9061ce8..38fd0513 100644 --- a/crates/pgt_treesitter_queries/src/queries/relations.rs +++ b/crates/pgt_treesitter_queries/src/queries/relations.rs @@ -14,6 +14,14 @@ static TS_QUERY: LazyLock = LazyLock::new(|| { (identifier)? @table )+ ) + (insert + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? 
@table + )+ + ) "#; tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") }); @@ -91,3 +99,101 @@ impl<'a> Query<'a> for RelationMatch<'a> { to_return } } + +#[cfg(test)] +mod tests { + use super::RelationMatch; + use crate::TreeSitterQueriesExecutor; + + #[test] + fn finds_table_without_schema() { + let sql = r#"select * from users;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), None); + assert_eq!(results[0].get_table(sql), "users"); + } + + #[test] + fn finds_table_with_schema() { + let sql = r#"select * from public.users;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("public".to_string())); + assert_eq!(results[0].get_table(sql), "users"); + } + + #[test] + fn finds_insert_into_with_schema_and_table() { + let sql = r#"insert into auth.accounts (id, email) values (1, 'a@b.com');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = 
executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("auth".to_string())); + assert_eq!(results[0].get_table(sql), "accounts"); + } + + #[test] + fn finds_insert_into_without_schema() { + let sql = r#"insert into users (id, email) values (1, 'a@b.com');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), None); + assert_eq!(results[0].get_table(sql), "users"); + } +} diff --git a/crates/pgt_treesitter_queries/src/queries/where_columns.rs b/crates/pgt_treesitter_queries/src/queries/where_columns.rs new file mode 100644 index 00000000..8e19590d --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/where_columns.rs @@ -0,0 +1,96 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (where + (binary_expression + (binary_expression + (field + (object_reference)? @alias + "."? + (identifier) @column + ) + ) + ) + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct WhereColumnMatch<'a> { + pub(crate) alias: Option>, + pub(crate) column: tree_sitter::Node<'a>, +} + +impl WhereColumnMatch<'_> { + pub fn get_alias(&self, sql: &str) -> Option { + let str = self + .alias + .as_ref()? 
+ .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch"); + + Some(str.to_string()) + } + + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a WhereColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::WhereClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> { + type Ref = &'a WhereColumnMatch<'a>; +} + +impl<'a> Query<'a> for WhereColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: None, + column: capture, + })); + } + + if m.captures.len() == 2 { + let alias = m.captures[0].node; + let column = m.captures[1].node; + + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: Some(alias), + column, + })); + } + } + + to_return + } +} diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 48f91ef4..bc2c6c3b 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -15,6 +15,7 @@ pub struct SQLFunctionArg { #[derive(Debug, Clone)] pub struct SQLFunctionSignature { + #[allow(dead_code)] pub schema: Option, pub name: String, pub args: Vec,