From f0c59c9095e89c1904585337e6515477c29bb9dd Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 6 May 2025 09:55:18 +0200 Subject: [PATCH 1/2] fix(completions): complete right columns right after JOIN ON --- crates/pgt_completions/src/context.rs | 178 ++++++++++-------- .../pgt_completions/src/providers/columns.rs | 53 ++++++ .../src/relevance/filtering.rs | 27 +-- .../pgt_completions/src/relevance/scoring.rs | 49 ++--- 4 files changed, 197 insertions(+), 110 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 6ace55b6..bf236bd4 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -9,11 +9,13 @@ use pgt_treesitter_queries::{ use crate::sanitization::SanitizedCompletionParams; #[derive(Debug, PartialEq, Eq)] -pub enum ClauseType { +pub enum WrappingClause<'a> { Select, Where, From, - Join, + Join { + on_node: Option>, + }, Update, Delete, } @@ -24,38 +26,6 @@ pub(crate) enum NodeText<'a> { Original(&'a str), } -impl TryFrom<&str> for ClauseType { - type Error = String; - - fn try_from(value: &str) -> Result { - match value { - "select" => Ok(Self::Select), - "where" => Ok(Self::Where), - "from" => Ok(Self::From), - "update" => Ok(Self::Update), - "delete" => Ok(Self::Delete), - "join" => Ok(Self::Join), - _ => { - let message = format!("Unimplemented ClauseType: {}", value); - - // Err on tests, so we notice that we're lacking an implementation immediately. - if cfg!(test) { - panic!("{}", message); - } - - Err(message) - } - } - } -} - -impl TryFrom for ClauseType { - type Error = String; - fn try_from(value: String) -> Result { - Self::try_from(value.as_str()) - } -} - /// We can map a few nodes, such as the "update" node, to actual SQL clauses. /// That gives us a lot of insight for completions. /// Other nodes, such as the "relation" node, gives us less but still @@ -127,7 +97,7 @@ pub(crate) struct CompletionContext<'a> { /// on u.id = i.user_id; /// ``` pub schema_or_alias_name: Option, - pub wrapping_clause_type: Option, + pub wrapping_clause_type: Option>, pub wrapping_node_kind: Option, @@ -266,7 +236,9 @@ impl<'a> CompletionContext<'a> { match parent_node_kind { "statement" | "subquery" => { - self.wrapping_clause_type = current_node_kind.try_into().ok(); + self.wrapping_clause_type = + self.get_wrapping_clause_from_current_node(current_node, &mut cursor); + self.wrapping_statement_range = Some(parent_node.range()); } "invocation" => self.is_invocation = true, @@ -277,39 +249,21 @@ impl<'a> CompletionContext<'a> { if self.is_in_error_node { let mut next_sibling = current_node.next_named_sibling(); while let Some(n) = next_sibling { - if n.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt { - "where" | "update" | "select" | "delete" | "from" | "join" => { - self.wrapping_clause_type = txt.try_into().ok(); - break; - } - _ => {} - } - }; + if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { + self.wrapping_clause_type = Some(clause_type); + break; + } else { + next_sibling = n.next_named_sibling(); } - next_sibling = n.next_named_sibling(); } let mut prev_sibling = current_node.prev_named_sibling(); while let Some(n) = prev_sibling { - if n.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt { - "where" | "update" | "select" | "delete" | "from" | "join" => { - self.wrapping_clause_type = txt.try_into().ok(); - break; - } - _ => {} - } - }; + if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { + self.wrapping_clause_type = Some(clause_type); + break; + } else { + prev_sibling = n.prev_named_sibling(); } - prev_sibling = n.prev_named_sibling(); } } @@ -330,7 +284,8 @@ impl<'a> CompletionContext<'a> { } "where" | "update" | "select" | "delete" | "from" | "join" => { - self.wrapping_clause_type = current_node_kind.try_into().ok(); + self.wrapping_clause_type = + self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } "relation" | "binary_expression" | "assignment" => { @@ -353,12 +308,67 @@ impl<'a> CompletionContext<'a> { cursor.goto_first_child_for_byte(self.position); self.gather_context_from_node(cursor, current_node); } + + fn get_wrapping_clause_from_keyword_node( + &self, + node: tree_sitter::Node<'a>, + ) -> Option> { + if node.kind().starts_with("keyword_") { + if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) { + match txt { + "where" => return Some(WrappingClause::Where), + "update" => return Some(WrappingClause::Update), + "select" => return Some(WrappingClause::Select), + "delete" => return Some(WrappingClause::Delete), + "from" => return Some(WrappingClause::From), + "join" => { + // TODO: not sure if we can infer it here. + return Some(WrappingClause::Join { on_node: None }); + } + _ => {} + } + }; + } + + None + } + + fn get_wrapping_clause_from_current_node( + &self, + node: tree_sitter::Node<'a>, + cursor: &mut tree_sitter::TreeCursor<'a>, + ) -> Option> { + return match node.kind() { + "where" => Some(WrappingClause::Where), + "update" => Some(WrappingClause::Update), + "select" => Some(WrappingClause::Select), + "delete" => Some(WrappingClause::Delete), + "from" => Some(WrappingClause::From), + "join" => { + // sadly, we need to manually iterate over the children – + // `node.child_by_field_id(..)` does not work as expected + let mut on_node = None; + for child in node.children(cursor) { + // 28 is the id for "keyword_on" + if child.kind_id() == 28 { + on_node = Some(child); + } + } + cursor.goto_parent(); + Some(WrappingClause::Join { on_node }) + } + _ => None, + }; + } } #[cfg(test)] mod tests { use crate::{ - context::{ClauseType, CompletionContext, NodeText}, + context::{CompletionContext, NodeText, WrappingClause}, sanitization::SanitizedCompletionParams, test_helper::{CURSOR_POS, get_text_and_position}, }; @@ -375,29 +385,41 @@ mod tests { #[test] fn identifies_clauses() { let test_cases = vec![ - (format!("Select {}* from users;", CURSOR_POS), "select"), - (format!("Select * from u{};", CURSOR_POS), "from"), + ( + format!("Select {}* from users;", CURSOR_POS), + WrappingClause::Select, + ), + ( + format!("Select * from u{};", CURSOR_POS), + WrappingClause::From, + ), ( format!("Select {}* from users where n = 1;", CURSOR_POS), - "select", + WrappingClause::Select, ), ( format!("Select * from users where {}n = 1;", CURSOR_POS), - "where", + WrappingClause::Where, ), ( format!("update users set u{} = 1 where n = 2;", CURSOR_POS), - "update", + WrappingClause::Update, ), ( format!("update users set u = 1 where n{} = 2;", CURSOR_POS), - "where", + WrappingClause::Where, + ), + ( + format!("delete{} from users;", CURSOR_POS), + WrappingClause::Delete, + ), + ( + format!("delete from {}users;", CURSOR_POS), + WrappingClause::From, ), - (format!("delete{} from users;", CURSOR_POS), "delete"), - (format!("delete from {}users;", CURSOR_POS), "from"), ( format!("select name, age, location from public.u{}sers", CURSOR_POS), - "from", + WrappingClause::From, ), ]; @@ -415,7 +437,7 @@ mod tests { let ctx = CompletionContext::new(¶ms); - assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); + assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); } } @@ -518,7 +540,7 @@ mod tests { assert_eq!( ctx.wrapping_clause_type, - Some(crate::context::ClauseType::Select) + Some(crate::context::WrappingClause::Select) ); } } @@ -596,6 +618,6 @@ mod tests { ctx.get_ts_node_content(node), Some(NodeText::Original("fro")) ); - assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select)); + assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); } } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 770a2b61..bd573430 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -431,4 +431,57 @@ mod tests { ) .await; } + + #[tokio::test] + async fn completes_in_join_on_clause() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + assert_complete_results( + format!( + "select u.id, auth.posts.content from auth.users u join auth.posts on u.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column), + CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), + CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), + ], + setup, + ) + .await; + + assert_complete_results( + format!( + "select u.id, p.content from auth.users u join auth.posts p on p.user_id = u.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("uid".to_string(), CompletionItemKind::Column), + CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), + CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 2658216b..c74d8c35 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext, WrappingNode}; +use crate::context::{CompletionContext, WrappingClause}; use super::CompletionRelevanceData; @@ -50,31 +50,36 @@ impl CompletionFilter<'_> { fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { let clause = ctx.wrapping_clause_type.as_ref(); - let wrapping_node = ctx.wrapping_node_kind.as_ref(); match self.data { CompletionRelevanceData::Table(_) => { - let in_select_clause = clause.is_some_and(|c| c == &ClauseType::Select); - let in_where_clause = clause.is_some_and(|c| c == &ClauseType::Where); + let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select); + let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where); if in_select_clause || in_where_clause { return None; }; } CompletionRelevanceData::Column(_) => { - let in_from_clause = clause.is_some_and(|c| c == &ClauseType::From); + let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From); if in_from_clause { return None; } - // We can complete columns in JOIN cluases, but only if we are in the - // "ON u.id = posts.user_id" part. - let in_join_clause = clause.is_some_and(|c| c == &ClauseType::Join); + // We can complete columns in JOIN cluases, but only if we are after the + // ON node in the "ON u.id = posts.user_id" part. + let in_join_clause_before_on_node = clause.is_some_and(|c| match c { + // we are in a JOIN, but definitely not after an ON + WrappingClause::Join { on_node: None } => true, - let in_comparison_clause = - wrapping_node.is_some_and(|n| n == &WrappingNode::BinaryExpression); + WrappingClause::Join { on_node: Some(on) } => ctx + .node_under_cursor + .is_some_and(|n| n.end_byte() < on.start_byte()), - if in_join_clause && !in_comparison_clause { + _ => false, + }); + + if in_join_clause_before_on_node { return None; } } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index e67df658..baff3960 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext, WrappingNode}; +use crate::context::{CompletionContext, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -64,40 +64,47 @@ impl CompletionScore<'_> { let has_mentioned_tables = !ctx.mentioned_relations.is_empty(); let has_mentioned_schema = ctx.schema_or_alias_name.is_some(); - let is_binary_exp = ctx - .wrapping_node_kind - .as_ref() - .is_some_and(|wn| wn == &WrappingNode::BinaryExpression); - self.score += match self.data { CompletionRelevanceData::Table(_) => match clause_type { - ClauseType::Update => 10, - ClauseType::Delete => 10, - ClauseType::From => 5, - ClauseType::Join if !is_binary_exp => 5, + WrappingClause::Update => 10, + WrappingClause::Delete => 10, + WrappingClause::From => 5, + WrappingClause::Join { on_node } + if on_node.is_none_or(|on| { + ctx.node_under_cursor + .is_none_or(|n| n.end_byte() < on.start_byte()) + }) => + { + 5 + } _ => -50, }, CompletionRelevanceData::Function(_) => match clause_type { - ClauseType::Select if !has_mentioned_tables => 15, - ClauseType::Select if has_mentioned_tables => 0, - ClauseType::From => 0, + WrappingClause::Select if !has_mentioned_tables => 15, + WrappingClause::Select if has_mentioned_tables => 0, + WrappingClause::From => 0, _ => -50, }, CompletionRelevanceData::Column(col) => match clause_type { - ClauseType::Select if has_mentioned_tables => 10, - ClauseType::Select if !has_mentioned_tables => 0, - ClauseType::Where => 10, - ClauseType::Join if is_binary_exp => { + WrappingClause::Select if has_mentioned_tables => 10, + WrappingClause::Select if !has_mentioned_tables => 0, + WrappingClause::Where => 10, + WrappingClause::Join { on_node } + if on_node.is_some_and(|on| { + ctx.node_under_cursor + .is_some_and(|n| n.start_byte() > on.end_byte()) + }) => + { // Users will probably join on primary keys if col.is_primary_key { 20 } else { 10 } } _ => -15, }, CompletionRelevanceData::Schema(_) => match clause_type { - ClauseType::From if !has_mentioned_schema => 15, - ClauseType::Join if !has_mentioned_schema => 15, - ClauseType::Update if !has_mentioned_schema => 15, - ClauseType::Delete if !has_mentioned_schema => 15, + WrappingClause::From if !has_mentioned_schema => 15, + WrappingClause::Join { .. } if !has_mentioned_schema => 15, + WrappingClause::Update if !has_mentioned_schema => 15, + WrappingClause::Delete if !has_mentioned_schema => 15, _ => -50, }, } From b9d63651355a17e5b9e17df306be48d66350f671 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 6 May 2025 10:09:02 +0200 Subject: [PATCH 2/2] lint --- crates/pgt_completions/src/context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index bf236bd4..d96d0d53 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -341,7 +341,7 @@ impl<'a> CompletionContext<'a> { node: tree_sitter::Node<'a>, cursor: &mut tree_sitter::TreeCursor<'a>, ) -> Option> { - return match node.kind() { + match node.kind() { "where" => Some(WrappingClause::Where), "update" => Some(WrappingClause::Update), "select" => Some(WrappingClause::Select), @@ -361,7 +361,7 @@ impl<'a> CompletionContext<'a> { Some(WrappingClause::Join { on_node }) } _ => None, - }; + } } }