From 5c8e8e8b314476c7032039b49ecdd04096ce1c11 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 4 Apr 2025 16:33:30 +0200 Subject: [PATCH 01/11] =?UTF-8?q?so=20far=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + crates/pgt_completions/Cargo.toml | 1 + crates/pgt_completions/src/complete.rs | 1 + crates/pgt_completions/src/context.rs | 2 + .../pgt_completions/src/providers/columns.rs | 82 ++++++++++++++++++- .../pgt_completions/src/providers/tables.rs | 4 +- crates/pgt_completions/src/relevance.rs | 22 ++--- postgrestools.jsonc | 2 +- 8 files changed, 97 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 79ec52f0..30d59465 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2452,6 +2452,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", + "tracing", "tree-sitter", "tree_sitter_sql", ] diff --git a/crates/pgt_completions/Cargo.toml b/crates/pgt_completions/Cargo.toml index dba88f41..559639f3 100644 --- a/crates/pgt_completions/Cargo.toml +++ b/crates/pgt_completions/Cargo.toml @@ -22,6 +22,7 @@ pgt_treesitter_queries.workspace = true schemars = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tracing = { workspace = true } tree-sitter.workspace = true tree_sitter_sql.workspace = true diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index fb00aeaf..7813ca52 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -17,6 +17,7 @@ pub struct CompletionParams<'a> { pub tree: Option<&'a tree_sitter::Tree>, } +#[tracing::instrument(level = "debug")] pub fn complete(params: CompletionParams) -> Vec { let ctx = CompletionContext::new(¶ms); diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 8b12742d..227a9ba9 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -81,6 +81,8 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); + println!("Here's my node: {:?}", ctx.ts_node.unwrap()); + ctx } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 3f1c5bb9..2898b63f 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -143,12 +143,86 @@ mod tests { let params = get_test_params(&tree, &cache, case.get_input_query()); let mut items = complete(params); - let _ = items.split_off(3); + let _ = items.split_off(6); - items.sort_by(|a, b| a.label.cmp(&b.label)); + #[derive(Eq, PartialEq, Debug)] + struct LabelAndDesc { + label: String, + desc: String, + } + + let labels: Vec = items + .into_iter() + .map(|c| LabelAndDesc { + label: c.label, + desc: c.description, + }) + .collect(); + + let expected = vec![ + ("name", "Table: public.users"), + ("narrator", "Table: public.audio_books"), + ("narrator_id", "Table: private.audio_books"), + ("name", "Schema: pg_catalog"), + ("nameconcatoid", "Schema: pg_catalog"), + ("nameeq", "Schema: pg_catalog"), + ] + .into_iter() + .map(|(label, schema)| LabelAndDesc { + label: label.into(), + desc: schema.into(), + }) + .collect::>(); + + assert_eq!(labels, expected); + } + + #[tokio::test] + async fn suggests_relevant_columns_without_letters() { + let setup = r#" + create table users ( + id serial primary key, + name text, + address text, + email text + ); + "#; + + let test_case = TestCase { + message: "suggests user created tables first", + query: format!(r#"select {} from users"#, CURSOR_POS), + label: "", + description: "", + }; + + let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let params = get_test_params(&tree, &cache, test_case.get_input_query()); + let results = complete(params); - let labels: Vec = items.into_iter().map(|c| c.label).collect(); + let (first_four, _rest) = results.split_at(4); + + let has_column_in_first_four = |col: &'static str| { + first_four + .iter() + .find(|compl_item| compl_item.label.as_str() == col) + .is_some() + }; - assert_eq!(labels, vec!["name", "narrator", "narrator_id"]); + assert!( + has_column_in_first_four("id"), + "`id` not present in first four completion items." + ); + assert!( + has_column_in_first_four("name"), + "`name` not present in first four completion items." + ); + assert!( + has_column_in_first_four("address"), + "`address` not present in first four completion items." + ); + assert!( + has_column_in_first_four("email"), + "`email` not present in first four completion items." + ); } } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 6a1e00c9..18fce14b 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -73,8 +73,8 @@ mod tests { "#; let test_cases = vec![ - (format!("select * from us{}", CURSOR_POS), "users"), - (format!("select * from em{}", CURSOR_POS), "emails"), + // (format!("select * from us{}", CURSOR_POS), "users"), + // (format!("select * from em{}", CURSOR_POS), "emails"), (format!("select * from {}", CURSOR_POS), "addresses"), ]; diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index ffe6cb22..af337357 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -33,7 +33,6 @@ impl CompletionRelevance<'_> { self.check_is_user_defined(); self.check_matches_schema(ctx); self.check_matches_query_input(ctx); - self.check_if_catalog(ctx); self.check_is_invocation(ctx); self.check_matching_clause_type(ctx); self.check_relations_in_stmt(ctx); @@ -52,7 +51,10 @@ impl CompletionRelevance<'_> { let name = match self.data { CompletionRelevanceData::Function(f) => f.name.as_str(), CompletionRelevanceData::Table(t) => t.name.as_str(), - CompletionRelevanceData::Column(c) => c.name.as_str(), + CompletionRelevanceData::Column(c) => { + // + c.name.as_str() + } }; if name.starts_with(content) { @@ -61,7 +63,7 @@ impl CompletionRelevance<'_> { .try_into() .expect("The length of the input exceeds i32 capacity"); - self.score += len * 5; + self.score += len * 10; }; } @@ -135,14 +137,6 @@ impl CompletionRelevance<'_> { } } - fn check_if_catalog(&mut self, ctx: &CompletionContext) { - if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") { - return; - } - - self.score -= 5; // unlikely that the user wants schema data - } - fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { match self.data { CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return, @@ -182,5 +176,11 @@ impl CompletionRelevance<'_> { if system_schemas.contains(&schema.as_str()) { self.score -= 10; } + + // "public" is the default postgres schema where users + // create objects. Prefer it by a slight bit. + if schema.as_str() == "public" { + self.score += 2; + } } } diff --git a/postgrestools.jsonc b/postgrestools.jsonc index 325c7861..0ce2e44f 100644 --- a/postgrestools.jsonc +++ b/postgrestools.jsonc @@ -17,7 +17,7 @@ // YOU CAN COMMENT ME OUT :) "db": { "host": "127.0.0.1", - "port": 5432, + "port": 54322, "username": "postgres", "password": "postgres", "database": "postgres", From 6c4812b3ff23832affd45a0ff968e847b54039dc Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 4 Apr 2025 17:43:11 +0200 Subject: [PATCH 02/11] hmmm --- crates/pgt_completions/src/complete.rs | 5 +- crates/pgt_completions/src/context.rs | 65 +++++++++++++++----- crates/pgt_completions/src/relevance.rs | 5 +- crates/pgt_workspace/src/workspace/server.rs | 20 +++--- 4 files changed, 72 insertions(+), 23 deletions(-) diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 7813ca52..5505daa2 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -17,7 +17,10 @@ pub struct CompletionParams<'a> { pub tree: Option<&'a tree_sitter::Tree>, } -#[tracing::instrument(level = "debug")] +#[tracing::instrument(level = "debug", skip_all, fields( + text = params.text, + position = params.position.to_string() +))] pub fn complete(params: CompletionParams) -> Vec { let ctx = CompletionContext::new(¶ms); diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 227a9ba9..776e97d1 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -49,12 +49,29 @@ impl TryFrom for ClauseType { } pub(crate) struct CompletionContext<'a> { - pub ts_node: Option>, + pub node_under_cursor: Option>, + pub previous_node: Option>, + pub tree: Option<&'a tree_sitter::Tree>, pub text: &'a str, pub schema_cache: &'a SchemaCache, pub position: usize, + /// If the cursor of the user is offset to the right of the statement, + /// we'll have to move it back to the last node, otherwise, tree-sitter will break. + /// However, knowing that the user is typing on the "next" node lets us prioritize different completion results. + /// We consider an offset of up to two characters as valid. + /// + /// Example: + /// + /// ``` + /// select * from {} + /// ``` + /// + /// We'll adjust the cursor position so it lies on the "from" token – but we're looking + /// for table completions. + pub cursor_offset_from_end: bool, + pub schema_name: Option, pub wrapping_clause_type: Option, pub is_invocation: bool, @@ -70,7 +87,9 @@ impl<'a> CompletionContext<'a> { text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), - ts_node: None, + cursor_offset_from_end: false, + previous_node: None, + node_under_cursor: None, schema_name: None, wrapping_clause_type: None, wrapping_statement_range: None, @@ -81,8 +100,6 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); - println!("Here's my node: {:?}", ctx.ts_node.unwrap()); - ctx } @@ -147,30 +164,34 @@ impl<'a> CompletionContext<'a> { * `select * from use {}` becomes `select * from use{}`. */ let current_node = cursor.node(); + let position_cache = self.position.clone(); while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { self.position -= 1; } + let cursor_offset = position_cache - self.position; + self.cursor_offset_from_end = cursor_offset > 0 && cursor_offset <= 2; + self.gather_context_from_node(cursor, current_node); } fn gather_context_from_node( &mut self, mut cursor: tree_sitter::TreeCursor<'a>, - previous_node: tree_sitter::Node<'a>, + parent_node: tree_sitter::Node<'a>, ) { let current_node = cursor.node(); // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node.kind() == previous_node.kind() { - self.ts_node = Some(current_node); + if current_node.kind() == parent_node.kind() { + self.node_under_cursor = Some(current_node); return; } - match previous_node.kind() { + match parent_node.kind() { "statement" | "subquery" => { self.wrapping_clause_type = current_node.kind().try_into().ok(); - self.wrapping_statement_range = Some(previous_node.range()); + self.wrapping_statement_range = Some(parent_node.range()); } "invocation" => self.is_invocation = true, @@ -202,7 +223,23 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - self.ts_node = Some(current_node); + if self.cursor_offset_from_end { + self.node_under_cursor = None; + self.previous_node = Some(current_node); + } else { + // for the previous node, either select the previous sibling, + // or collect the parent's previous sibling's last child. + let previous = match current_node.prev_sibling() { + Some(n) => Some(n), + None => { + let sib_of_parent = parent_node.prev_sibling(); + sib_of_parent.and_then(|p| p.children(&mut cursor).last()) + } + }; + self.node_under_cursor = Some(current_node); + self.previous_node = previous; + } + return; } @@ -361,7 +398,7 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); assert_eq!(ctx.get_ts_node_content(node), Some("select")); @@ -389,7 +426,7 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); assert_eq!(ctx.get_ts_node_content(node), Some("from")); assert_eq!( @@ -415,7 +452,7 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); assert_eq!(ctx.get_ts_node_content(node), Some("")); assert_eq!(ctx.wrapping_clause_type, None); @@ -440,7 +477,7 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); assert_eq!(ctx.get_ts_node_content(node), Some("fro")); assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select)); diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index af337357..4104556e 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -41,7 +41,10 @@ impl CompletionRelevance<'_> { } fn check_matches_query_input(&mut self, ctx: &CompletionContext) { - let node = ctx.ts_node.unwrap(); + let node = match ctx.node_under_cursor { + Some(node) => node, + None => return, + }; let content = match ctx.get_ts_node_content(node) { Some(c) => c, diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 8dcbfb1d..4a52817c 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -13,6 +13,7 @@ use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; use pgt_fs::{ConfigName, PgTPath}; +use pgt_text_size::{TextRange, TextSize}; use pgt_typecheck::TypecheckParams; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; @@ -535,13 +536,18 @@ impl Workspace for WorkspaceServer { .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let (statement, stmt_range, text) = match doc - .iter_statements_with_text_and_range() - .find(|(_, r, _)| r.contains(params.position)) - { - Some(s) => s, - None => return Ok(CompletionsResult::default()), - }; + let (statement, stmt_range, text) = + match doc.iter_statements_with_text_and_range().find(|(_, r, _)| { + let expanded_range = TextRange::new( + r.start(), + r.end().checked_add(TextSize::new(2)).unwrap_or(r.end()), + ); + + expanded_range.contains(params.position) + }) { + Some(s) => s, + None => return Ok(CompletionsResult::default()), + }; // `offset` is the position in the document, // but we need the position within the *statement*. From 4e5521cb2a7c18a9ecea72e8f70750cdc109ed49 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 9 Apr 2025 08:44:56 +0200 Subject: [PATCH 03/11] awesome! --- Cargo.lock | 16 ++- crates/pgt_completions/src/complete.rs | 125 +++++++++++++++++- crates/pgt_completions/src/context.rs | 57 ++------ crates/pgt_workspace/Cargo.toml | 1 + crates/pgt_workspace/src/workspace/server.rs | 77 +++++++++-- .../src/workspace/server/document.rs | 4 + 6 files changed, 219 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 30d59465..779d82b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1842,6 +1842,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -2361,7 +2370,7 @@ dependencies = [ "cc", "fs_extra", "glob", - "itertools", + "itertools 0.10.5", "prost", "prost-build", "serde", @@ -2759,6 +2768,7 @@ dependencies = [ "futures", "globset", "ignore", + "itertools 0.14.0", "pgt_analyse", "pgt_analyser", "pgt_completions", @@ -2999,7 +3009,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck", - "itertools", + "itertools 0.10.5", "log", "multimap", "once_cell", @@ -3019,7 +3029,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.90", diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 5505daa2..f18fc15f 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -21,8 +21,45 @@ pub struct CompletionParams<'a> { text = params.text, position = params.position.to_string() ))] -pub fn complete(params: CompletionParams) -> Vec { - let ctx = CompletionContext::new(¶ms); +pub fn complete(mut params: CompletionParams) -> Vec { + let should_adjust_params = params.tree.is_some() + && (cursor_inbetween_nodes(params.tree.unwrap(), params.position) + || cursor_prepared_to_write_token_after_last_node( + params.tree.unwrap(), + params.position, + )); + + let usable_sql = if should_adjust_params { + let pos: usize = params.position.into(); + + let mut mutated_sql = String::new(); + + for (idx, c) in params.text.chars().enumerate() { + if idx == pos { + mutated_sql.push_str("REPLACED_TOKEN "); + } + mutated_sql.push(c); + } + + mutated_sql + } else { + params.text + }; + + let usable_tree = if should_adjust_params { + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + parser.parse(usable_sql.clone(), None) + } else { + tracing::info!("We're reusing the previous tree."); + None + }; + + params.text = usable_sql; + + let ctx = CompletionContext::new(¶ms, usable_tree.as_ref().or(params.tree)); let mut builder = CompletionBuilder::new(); @@ -32,3 +69,87 @@ pub fn complete(params: CompletionParams) -> Vec { builder.finish() } + +fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { + let mut cursor = tree.walk(); + let mut node = tree.root_node(); + + loop { + let child_dx = cursor.goto_first_child_for_byte(position.into()); + if child_dx.is_none() { + break; + } + node = cursor.node(); + } + + let byte = position.into(); + + // Return true if the cursor is NOT within the node's bounds, INCLUSIVE + !(node.start_byte() <= byte && node.end_byte() >= byte) +} + +fn cursor_prepared_to_write_token_after_last_node( + tree: &tree_sitter::Tree, + position: TextSize, +) -> bool { + let cursor_pos: usize = position.into(); + cursor_pos == tree.root_node().end_byte() + 1 +} + +#[cfg(test)] +mod tests { + use pgt_text_size::TextSize; + + use crate::complete::{cursor_inbetween_nodes, cursor_prepared_to_write_token_after_last_node}; + + #[test] + fn test_cursor_inbetween_nodes() { + let input = "select from users;"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select | from users; + assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7))); + + // select |from users; + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(8))); + + // select| from users; + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(6))); + } + + #[test] + fn test_cursor_after_nodes() { + let input = "select * from "; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select * from|; <-- still on previous token + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(14) + )); + + // select * from |; <-- too far off + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(16) + )); + + // select * from |; <-- just right + assert!(cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(15) + )); + } +} diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 776e97d1..7ee88a63 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -50,28 +50,12 @@ impl TryFrom for ClauseType { pub(crate) struct CompletionContext<'a> { pub node_under_cursor: Option>, - pub previous_node: Option>, pub tree: Option<&'a tree_sitter::Tree>, pub text: &'a str, pub schema_cache: &'a SchemaCache, pub position: usize, - /// If the cursor of the user is offset to the right of the statement, - /// we'll have to move it back to the last node, otherwise, tree-sitter will break. - /// However, knowing that the user is typing on the "next" node lets us prioritize different completion results. - /// We consider an offset of up to two characters as valid. - /// - /// Example: - /// - /// ``` - /// select * from {} - /// ``` - /// - /// We'll adjust the cursor position so it lies on the "from" token – but we're looking - /// for table completions. - pub cursor_offset_from_end: bool, - pub schema_name: Option, pub wrapping_clause_type: Option, pub is_invocation: bool, @@ -81,14 +65,12 @@ pub(crate) struct CompletionContext<'a> { } impl<'a> CompletionContext<'a> { - pub fn new(params: &'a CompletionParams) -> Self { + pub fn new(params: &'a CompletionParams, usable_tree: Option<&'a tree_sitter::Tree>) -> Self { let mut ctx = Self { - tree: params.tree, + tree: usable_tree, text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), - cursor_offset_from_end: false, - previous_node: None, node_under_cursor: None, schema_name: None, wrapping_clause_type: None, @@ -97,7 +79,10 @@ impl<'a> CompletionContext<'a> { mentioned_relations: HashMap::new(), }; + tracing::warn!("gathering tree context"); ctx.gather_tree_context(); + + tracing::warn!("gathering info from ts query"); ctx.gather_info_from_ts_queries(); ctx @@ -164,14 +149,10 @@ impl<'a> CompletionContext<'a> { * `select * from use {}` becomes `select * from use{}`. */ let current_node = cursor.node(); - let position_cache = self.position.clone(); while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { self.position -= 1; } - let cursor_offset = position_cache - self.position; - self.cursor_offset_from_end = cursor_offset > 0 && cursor_offset <= 2; - self.gather_context_from_node(cursor, current_node); } @@ -223,23 +204,11 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - if self.cursor_offset_from_end { + if self.get_ts_node_content(current_node).unwrap() == "REPLACED_TOKEN" { self.node_under_cursor = None; - self.previous_node = Some(current_node); } else { - // for the previous node, either select the previous sibling, - // or collect the parent's previous sibling's last child. - let previous = match current_node.prev_sibling() { - Some(n) => Some(n), - None => { - let sib_of_parent = parent_node.prev_sibling(); - sib_of_parent.and_then(|p| p.children(&mut cursor).last()) - } - }; self.node_under_cursor = Some(current_node); - self.previous_node = previous; } - return; } @@ -305,7 +274,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); } @@ -337,7 +306,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string())); } @@ -371,7 +340,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); assert_eq!(ctx.is_invocation, is_invocation); } @@ -396,7 +365,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); let node = ctx.node_under_cursor.unwrap(); @@ -424,7 +393,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); let node = ctx.node_under_cursor.unwrap(); @@ -450,7 +419,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); let node = ctx.node_under_cursor.unwrap(); @@ -475,7 +444,7 @@ mod tests { schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms, Some(&tree)); let node = ctx.node_under_cursor.unwrap(); diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index 7df42b19..862184a6 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -18,6 +18,7 @@ futures = "0.3.31" globset = "0.4.16" ignore = { workspace = true } +itertools = { version = "0.14.0" } pgt_analyse = { workspace = true, features = ["serde"] } pgt_analyser = { workspace = true } pgt_completions = { workspace = true } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 4a52817c..a394c341 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -8,6 +8,7 @@ use db_connection::DbConnection; pub(crate) use document::StatementId; use document::{Document, Statement}; use futures::{StreamExt, stream}; +use itertools::Itertools; use pg_query::PgQueryStore; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; @@ -528,7 +529,10 @@ impl Workspace for WorkspaceServer { ) -> Result { let pool = match self.connection.read().unwrap().get_pool() { Some(pool) => pool, - None => return Ok(CompletionsResult::default()), + None => { + tracing::debug!("No connection to database. Skipping completions."); + return Ok(CompletionsResult::default()); + } }; let doc = self @@ -536,18 +540,67 @@ impl Workspace for WorkspaceServer { .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let (statement, stmt_range, text) = - match doc.iter_statements_with_text_and_range().find(|(_, r, _)| { - let expanded_range = TextRange::new( - r.start(), - r.end().checked_add(TextSize::new(2)).unwrap_or(r.end()), - ); + let count = doc.statement_count(); + + let maybe_statement = if count == 0 { + None + } else if count == 1 { + let (stmt, range, txt) = doc.iter_statements_with_text_and_range().next().unwrap(); + let expanded_range = TextRange::new( + range.start(), + range + .end() + .checked_add(TextSize::new(2)) + .unwrap_or(range.end()), + ); + if expanded_range.contains(params.position) { + Some((stmt, range, txt)) + } else { + None + } + } else { + let mut stmts = doc.iter_statements_with_text_and_range().tuple_windows(); + stmts.find(|((_, rcurrent, _), (_, rnext, _))| { + /* + * We allow an offset of two for the statement: + * + * (| is the user's cursor.) + * + * select * from | <-- we want to suggest items for the next token. + * + */ + let expanded_range = TextRange::new( + rcurrent.start(), + rcurrent + .end() + .checked_add(TextSize::new(2)) + .unwrap_or(rcurrent.end()), + ); + let is_within_range = expanded_range.contains(params.position); - expanded_range.contains(params.position) - }) { - Some(s) => s, - None => return Ok(CompletionsResult::default()), - }; + /* + * However, we do not allow this if the there the offset overlaps + * with an adjacent statement: + * + * select 1; |select 1; + */ + let overlaps_next = !rnext.contains(params.position); + + + tracing::warn!("Current range {:?}, next range {:?}, position: {:?}, contains range: {}, overlaps :{}", rcurrent, rnext, params.position, is_within_range, overlaps_next); + + is_within_range && !overlaps_next + + }).map(|(t1,_t2)| t1) + }; + + let (statement, stmt_range, text) = match maybe_statement { + Some(tuple) => tuple, + None => { + tracing::debug!("No matching statement found for completion."); + return Ok(CompletionsResult::default()); + } + }; // `offset` is the position in the document, // but we need the position within the *statement*. diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 9ef8c234..44079a61 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -104,6 +104,10 @@ impl Document { }) } + pub fn statement_count(&self) -> usize { + self.positions.len() + } + pub fn get_txt(&self, stmt_id: StatementId) -> Option { self.positions .iter() From 4f00abb992842ee2adefe3314d464f81ea986614 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 9 Apr 2025 09:24:00 +0200 Subject: [PATCH 04/11] refactor, terminate by semicolons --- crates/pgt_text_size/src/range.rs | 18 +++++ crates/pgt_workspace/src/workspace/server.rs | 80 +++++++++---------- .../src/workspace/server/document.rs | 10 +++ 3 files changed, 66 insertions(+), 42 deletions(-) diff --git a/crates/pgt_text_size/src/range.rs b/crates/pgt_text_size/src/range.rs index 95b0db58..3cfc3c96 100644 --- a/crates/pgt_text_size/src/range.rs +++ b/crates/pgt_text_size/src/range.rs @@ -281,6 +281,24 @@ impl TextRange { }) } + /// Expand the range's end by the given offset. + /// + /// # Examples + /// + /// ```rust + /// # use pgt_text_size::*; + /// assert_eq!( + /// TextRange::new(2.into(), 4.into()).checked_expand_end(16.into()).unwrap(), + /// TextRange::new(2.into(), 20.into()), + /// ); + /// ``` + #[inline] + pub fn checked_expand_end(self, offset: TextSize) -> Option { + Some(TextRange { + start: self.start, + end: self.end.checked_add(offset)?, + }) + } /// Subtract an offset from this range. /// /// Note that this is not appropriate for changing where a `TextRange` is diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index a394c341..de573ef0 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -541,61 +541,57 @@ impl Workspace for WorkspaceServer { .ok_or(WorkspaceError::not_found())?; let count = doc.statement_count(); + // no arms no cookies + if count == 0 { + return Ok(CompletionsResult::default()); + } - let maybe_statement = if count == 0 { - None - } else if count == 1 { + /* + * We allow an offset of two for the statement: + * + * select * from | <-- we want to suggest items for the next token. + * + * However, if the current statement is terminated by a semicolon, we don't apply any + * offset. + * + * select * from users; | <-- no autocompletions here. + */ + let matches_expanding_range = + |stmt_id: StatementId, range: &TextRange, position: TextSize| { + let measuring_range = if doc.is_terminated_by_semicolon(stmt_id).unwrap() { + *range + } else { + range.checked_expand_end(2.into()).unwrap_or(*range) + }; + measuring_range.contains(position) + }; + + let maybe_statement = if count == 1 { let (stmt, range, txt) = doc.iter_statements_with_text_and_range().next().unwrap(); - let expanded_range = TextRange::new( - range.start(), - range - .end() - .checked_add(TextSize::new(2)) - .unwrap_or(range.end()), - ); - if expanded_range.contains(params.position) { + if matches_expanding_range(stmt.id, range, params.position) { Some((stmt, range, txt)) } else { None } } else { - let mut stmts = doc.iter_statements_with_text_and_range().tuple_windows(); - stmts.find(|((_, rcurrent, _), (_, rnext, _))| { /* - * We allow an offset of two for the statement: - * - * (| is the user's cursor.) - * - * select * from | <-- we want to suggest items for the next token. + * If we have multiple statements, we want to make sure that we do not overlap + * with the next one. * + * select 1 |select 1; */ - let expanded_range = TextRange::new( - rcurrent.start(), - rcurrent - .end() - .checked_add(TextSize::new(2)) - .unwrap_or(rcurrent.end()), - ); - let is_within_range = expanded_range.contains(params.position); - - /* - * However, we do not allow this if the there the offset overlaps - * with an adjacent statement: - * - * select 1; |select 1; - */ - let overlaps_next = !rnext.contains(params.position); - - - tracing::warn!("Current range {:?}, next range {:?}, position: {:?}, contains range: {}, overlaps :{}", rcurrent, rnext, params.position, is_within_range, overlaps_next); - - is_within_range && !overlaps_next - - }).map(|(t1,_t2)| t1) + let mut stmts = doc.iter_statements_with_text_and_range().tuple_windows(); + stmts + .find(|((current_stmt, rcurrent, _), (_, rnext, _))| { + let overlaps_next = rnext.contains(params.position); + matches_expanding_range(current_stmt.id, &rcurrent, params.position) + && !overlaps_next + }) + .map(|t| t.0) }; let (statement, stmt_range, text) = match maybe_statement { - Some(tuple) => tuple, + Some(it) => it, None => { tracing::debug!("No matching statement found for completion."); return Ok(CompletionsResult::default()); diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 44079a61..a1bd74d4 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -117,6 +117,16 @@ impl Document { stmt.to_owned() }) } + + pub fn is_terminated_by_semicolon(&self, stmt_id: StatementId) -> Option { + self.positions + .iter() + .find(|pos| pos.0 == stmt_id) + .map(|(_, range)| { + let final_char = self.content.chars().nth(range.end().into()); + final_char == Some(';') + }) + } } pub(crate) struct IdGenerator { From 5c0a7a6a314f30e5cd1e2be0e831568bd2bae26c Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 11:56:45 +0200 Subject: [PATCH 05/11] =?UTF-8?q?got=20the=20tests=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_completions/src/complete.rs | 128 +------- crates/pgt_completions/src/context.rs | 119 +++++--- crates/pgt_completions/src/lib.rs | 1 + .../pgt_completions/src/providers/tables.rs | 1 + crates/pgt_completions/src/relevance.rs | 7 +- crates/pgt_completions/src/sanitization.rs | 284 ++++++++++++++++++ crates/pgt_completions/src/test_helper.rs | 4 +- 7 files changed, 372 insertions(+), 172 deletions(-) create mode 100644 crates/pgt_completions/src/sanitization.rs diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index f18fc15f..0d775aed 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -5,6 +5,7 @@ use crate::{ context::CompletionContext, item::CompletionItem, providers::{complete_columns, complete_functions, complete_tables}, + sanitization::SanitizedCompletionParams, }; pub const LIMIT: usize = 50; @@ -21,45 +22,16 @@ pub struct CompletionParams<'a> { text = params.text, position = params.position.to_string() ))] -pub fn complete(mut params: CompletionParams) -> Vec { - let should_adjust_params = params.tree.is_some() - && (cursor_inbetween_nodes(params.tree.unwrap(), params.position) - || cursor_prepared_to_write_token_after_last_node( - params.tree.unwrap(), - params.position, - )); - - let usable_sql = if should_adjust_params { - let pos: usize = params.position.into(); - - let mut mutated_sql = String::new(); - - for (idx, c) in params.text.chars().enumerate() { - if idx == pos { - mutated_sql.push_str("REPLACED_TOKEN "); - } - mutated_sql.push(c); +pub fn complete(params: CompletionParams) -> Vec { + let sanitized_params = match SanitizedCompletionParams::try_from(params) { + Ok(p) => p, + Err(err) => { + tracing::warn!("Not possible to get completions: {}", err); + return vec![]; } - - mutated_sql - } else { - params.text }; - let usable_tree = if should_adjust_params { - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - parser.parse(usable_sql.clone(), None) - } else { - tracing::info!("We're reusing the previous tree."); - None - }; - - params.text = usable_sql; - - let ctx = CompletionContext::new(¶ms, usable_tree.as_ref().or(params.tree)); + let ctx = CompletionContext::new(&sanitized_params); let mut builder = CompletionBuilder::new(); @@ -69,87 +41,3 @@ pub fn complete(mut params: CompletionParams) -> Vec { builder.finish() } - -fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { - let mut cursor = tree.walk(); - let mut node = tree.root_node(); - - loop { - let child_dx = cursor.goto_first_child_for_byte(position.into()); - if child_dx.is_none() { - break; - } - node = cursor.node(); - } - - let byte = position.into(); - - // Return true if the cursor is NOT within the node's bounds, INCLUSIVE - !(node.start_byte() <= byte && node.end_byte() >= byte) -} - -fn cursor_prepared_to_write_token_after_last_node( - tree: &tree_sitter::Tree, - position: TextSize, -) -> bool { - let cursor_pos: usize = position.into(); - cursor_pos == tree.root_node().end_byte() + 1 -} - -#[cfg(test)] -mod tests { - use pgt_text_size::TextSize; - - use crate::complete::{cursor_inbetween_nodes, cursor_prepared_to_write_token_after_last_node}; - - #[test] - fn test_cursor_inbetween_nodes() { - let input = "select from users;"; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let mut tree = parser.parse(input.to_string(), None).unwrap(); - - // select | from users; - assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7))); - - // select |from users; - assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(8))); - - // select| from users; - assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(6))); - } - - #[test] - fn test_cursor_after_nodes() { - let input = "select * from "; - - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let mut tree = parser.parse(input.to_string(), None).unwrap(); - - // select * from|; <-- still on previous token - assert!(!cursor_prepared_to_write_token_after_last_node( - &mut tree, - TextSize::new(14) - )); - - // select * from |; <-- too far off - assert!(!cursor_prepared_to_write_token_after_last_node( - &mut tree, - TextSize::new(16) - )); - - // select * from |; <-- just right - assert!(cursor_prepared_to_write_token_after_last_node( - &mut tree, - TextSize::new(15) - )); - } -} diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 7ee88a63..45e287b1 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -6,7 +6,7 @@ use pgt_treesitter_queries::{ queries::{self, QueryResult}, }; -use crate::CompletionParams; +use crate::sanitization::SanitizedCompletionParams; #[derive(Debug, PartialEq, Eq)] pub enum ClauseType { @@ -17,6 +17,12 @@ pub enum ClauseType { Delete, } +#[derive(PartialEq, Eq, Debug)] +pub(crate) enum NodeText<'a> { + Replaced, + Original(&'a str), +} + impl TryFrom<&str> for ClauseType { type Error = String; @@ -51,7 +57,7 @@ impl TryFrom for ClauseType { pub(crate) struct CompletionContext<'a> { pub node_under_cursor: Option>, - pub tree: Option<&'a tree_sitter::Tree>, + pub tree: &'a tree_sitter::Tree, pub text: &'a str, pub schema_cache: &'a SchemaCache, pub position: usize, @@ -65,9 +71,9 @@ pub(crate) struct CompletionContext<'a> { } impl<'a> CompletionContext<'a> { - pub fn new(params: &'a CompletionParams, usable_tree: Option<&'a tree_sitter::Tree>) -> Self { + pub fn new(params: &'a SanitizedCompletionParams) -> Self { let mut ctx = Self { - tree: usable_tree, + tree: params.tree.as_ref(), text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), @@ -89,15 +95,10 @@ impl<'a> CompletionContext<'a> { } fn gather_info_from_ts_queries(&mut self) { - let tree = match self.tree.as_ref() { - None => return, - Some(t) => t, - }; - let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; - let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql); executor.add_query_results::(); @@ -124,17 +125,19 @@ impl<'a> CompletionContext<'a> { } } - pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { + pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option> { let source = self.text; - ts_node.utf8_text(source.as_bytes()).ok() + ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { + if SanitizedCompletionParams::is_sanitized_token(txt) { + NodeText::Replaced + } else { + NodeText::Original(txt) + } + }) } fn gather_tree_context(&mut self) { - if self.tree.is_none() { - return; - } - - let mut cursor = self.tree.as_ref().unwrap().root_node().walk(); + let mut cursor = self.tree.root_node().walk(); /* * The head node of any treesitter tree is always the "PROGRAM" node. @@ -181,11 +184,16 @@ impl<'a> CompletionContext<'a> { match current_node.kind() { "object_reference" => { - let txt = self.get_ts_node_content(current_node); - if let Some(txt) = txt { - let parts: Vec<&str> = txt.split('.').collect(); - if parts.len() == 2 { - self.schema_name = Some(parts[0].to_string()); + let content = self.get_ts_node_content(current_node); + if let Some(node_txt) = content { + match node_txt { + NodeText::Original(txt) => { + let parts: Vec<&str> = txt.split('.').collect(); + if parts.len() == 2 { + self.schema_name = Some(parts[0].to_string()); + } + } + NodeText::Replaced => {} } } } @@ -204,7 +212,10 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - if self.get_ts_node_content(current_node).unwrap() == "REPLACED_TOKEN" { + if matches!( + self.get_ts_node_content(current_node).unwrap(), + NodeText::Replaced + ) { self.node_under_cursor = None; } else { self.node_under_cursor = Some(current_node); @@ -220,7 +231,8 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { use crate::{ - context::{ClauseType, CompletionContext}, + context::{ClauseType, CompletionContext, NodeText}, + sanitization::SanitizedCompletionParams, test_helper::{CURSOR_POS, get_text_and_position}, }; @@ -267,14 +279,14 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); } @@ -299,14 +311,14 @@ mod tests { let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string())); } @@ -333,14 +345,14 @@ mod tests { let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.is_invocation, is_invocation); } @@ -358,18 +370,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("select")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("select")) + ); assert_eq!( ctx.wrapping_clause_type, @@ -386,18 +401,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("from")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("from")) + ); assert_eq!( ctx.wrapping_clause_type, Some(crate::context::ClauseType::From) @@ -412,18 +430,18 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("")); + assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original(""))); assert_eq!(ctx.wrapping_clause_type, None); } @@ -437,18 +455,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: Some(&tree), + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; - let ctx = CompletionContext::new(¶ms, Some(&tree)); + let ctx = CompletionContext::new(¶ms); let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("fro")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("fro")) + ); assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select)); } } diff --git a/crates/pgt_completions/src/lib.rs b/crates/pgt_completions/src/lib.rs index 62470ff4..c37c4d0f 100644 --- a/crates/pgt_completions/src/lib.rs +++ b/crates/pgt_completions/src/lib.rs @@ -4,6 +4,7 @@ mod context; mod item; mod providers; mod relevance; +mod sanitization; #[cfg(test)] mod test_helper; diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 18fce14b..f245abc8 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -81,6 +81,7 @@ mod tests { for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; let params = get_test_params(&tree, &cache, query.as_str().into()); + println!("{}, {}", ¶ms.text, ¶ms.position); let items = complete(params); assert!(!items.is_empty()); diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index 4104556e..e3fa5918 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext}; +use crate::context::{ClauseType, CompletionContext, NodeText}; #[derive(Debug)] pub(crate) enum CompletionRelevanceData<'a> { @@ -47,7 +47,10 @@ impl CompletionRelevance<'_> { }; let content = match ctx.get_ts_node_content(node) { - Some(c) => c, + Some(c) => match c { + NodeText::Original(s) => s, + NodeText::Replaced => return, + }, None => return, }; diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs new file mode 100644 index 00000000..fb1a73b6 --- /dev/null +++ b/crates/pgt_completions/src/sanitization.rs @@ -0,0 +1,284 @@ +use std::borrow::Cow; + +use pgt_text_size::TextSize; + +use crate::CompletionParams; + +pub(crate) struct SanitizedCompletionParams<'a> { + pub position: TextSize, + pub text: String, + pub schema: &'a pgt_schema_cache::SchemaCache, + pub tree: Cow<'a, tree_sitter::Tree>, +} + +impl<'larger, 'smaller> TryFrom> for SanitizedCompletionParams<'smaller> +where + 'larger: 'smaller, +{ + type Error = String; + + fn try_from(params: CompletionParams<'larger>) -> Result { + let tree = match ¶ms.tree { + Some(tree) => tree, + None => return Err("Tree required for autocompletions.".to_string()), + }; + + if cursor_inbetween_nodes(tree, params.position) + || cursor_prepared_to_write_token_after_last_node(tree, params.position) + { + Ok(SanitizedCompletionParams::with_adjusted_sql(params)) + } else { + Ok(SanitizedCompletionParams::unadjusted(params)) + } + } +} + +static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; + +impl<'larger, 'smaller> SanitizedCompletionParams<'smaller> +where + 'larger: 'smaller, +{ + fn with_adjusted_sql(params: CompletionParams<'larger>) -> Self { + let cursor_pos: usize = params.position.into(); + let mut sql = String::new(); + + for (idx, c) in params.text.chars().enumerate() { + if idx == cursor_pos { + sql.push_str(SANITIZED_TOKEN); + sql.push(' '); + } + sql.push(c); + } + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + let tree = parser.parse(sql.clone(), None).unwrap(); + + Self { + position: params.position, + text: sql, + schema: params.schema, + tree: Cow::Owned(tree), + } + } + fn unadjusted(params: CompletionParams<'larger>) -> Self { + Self { + position: params.position, + text: params.text.clone(), + schema: params.schema, + tree: Cow::Borrowed(params.tree.unwrap()), + } + } + + pub fn is_sanitized_token(txt: &str) -> bool { + txt == SANITIZED_TOKEN + } +} + +/// Checks if the cursor is positioned inbetween two SQL nodes. +/// +/// ```sql +/// select| from users; -- cursor "touches" select node. returns false. +/// select |from users; -- cursor "touches" from node. returns false. +/// select | from users; -- cursor is between select and from nodes. returns true. +/// ``` +fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { + let mut cursor = tree.walk(); + let mut leaf_node = tree.root_node(); + + let byte = position.into(); + + // if the cursor escapes the root node, it can't be between nodes. + if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { + return false; + } + + /* + * Get closer and closer to the leaf node, until + * a) there is no more child *for the node* or + * b) there is no more child *under the cursor*. + */ + loop { + let child_idx = cursor.goto_first_child_for_byte(position.into()); + if child_idx.is_none() { + break; + } + leaf_node = cursor.node(); + } + + let cursor_on_leafnode = byte >= leaf_node.start_byte() && leaf_node.end_byte() >= byte; + + /* + * The cursor is inbetween nodes if it is not within the range + * of a leaf node. + */ + !cursor_on_leafnode +} + +/// Checks if the cursor is positioned after the last node, +/// ready to write the next token: +/// +/// ```sql +/// select * from | -- ready to write! +/// select * from| -- user still needs to type a space +/// select * from | -- too far off. +/// ``` +fn cursor_prepared_to_write_token_after_last_node( + tree: &tree_sitter::Tree, + position: TextSize, +) -> bool { + let cursor_pos: usize = position.into(); + cursor_pos == tree.root_node().end_byte() + 1 +} + +fn cursor_before_or_on_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { + let mut cursor = tree.walk(); + let mut leaf_node = tree.root_node(); + + let byte: usize = position.into(); + + // if the cursor escapes the root node, it can't be between nodes. + if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { + return false; + } + + loop { + let child_idx = cursor.goto_first_child_for_byte(position.into()); + if child_idx.is_none() { + break; + } + leaf_node = cursor.node(); + } + + // The semicolon node is on the same level as the statement: + // + // program [0..26] + // statement [0..19] + // ; [25..26] + // + // However, if we search for position 21, we'll still land on the semi node. + // We must manually verify that the cursor is between the statement and the semi nodes. + + // if the last node is not a semi, the statement is not completed. + if leaf_node.kind() != ";" { + return false; + } + + // not okay to be on the semi. + if byte == leaf_node.start_byte() { + return false; + } + + leaf_node + .prev_named_sibling() + .map(|n| n.end_byte() < byte) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use pgt_text_size::TextSize; + + use crate::sanitization::{ + cursor_before_or_on_semicolon, cursor_inbetween_nodes, + cursor_prepared_to_write_token_after_last_node, + }; + + #[test] + fn test_cursor_inbetween_nodes() { + // note: two spaces between select and from. + let input = "select from users;"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select | from users; <-- just right, one space after select token, one space before from + assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7))); + + // select| from users; <-- still on select token + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(6))); + + // select |from users; <-- already on from token + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(8))); + + // select from users;| + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(19))); + } + + #[test] + fn test_cursor_after_nodes() { + let input = "select * from"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select * from| <-- still on previous token + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(13) + )); + + // select * from | <-- too far off, two spaces afterward + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(15) + )); + + // select * |from <-- it's within + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(9) + )); + + // select * from | <-- just right + assert!(cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(14) + )); + } + + #[test] + fn test_cursor_before_semicolon() { + // Idx "13" is the exlusive end of `select * from` (first space after from) + // Idx "18" is right where the semi is + let input = "select * from ;"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select * from ;| <-- it's after the statement + assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(19))); + + // select * from| ; <-- still touches the from + assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(13))); + + // not okay to be ON the semi. + // select * from |; + assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(18))); + + // anything is fine here + // select * from | ; + // select * from | ; + // select * from | ; + // select * from |; + assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(14))); + assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(15))); + assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(16))); + assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(17))); + } +} diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index a54aacbd..af49893f 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -15,9 +15,11 @@ impl From<&str> for InputQuery { fn from(value: &str) -> Self { let position = value .find(CURSOR_POS) - .map(|p| p.saturating_sub(1)) + .map(|p| p.saturating_add(1)) .expect("Insert Cursor Position into your Query."); + println!("{}", position); + InputQuery { sql: value.replace(CURSOR_POS, ""), position, From cf61598bd6dc75236bd9ec3ef30d56bbe9d31921 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 12:06:38 +0200 Subject: [PATCH 06/11] hell yeah --- .../pgt_completions/src/providers/tables.rs | 7 ++- crates/pgt_completions/src/sanitization.rs | 19 ++++---- crates/pgt_completions/src/test_helper.rs | 43 +++++++++++++++++-- 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index f245abc8..2074a4f1 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -73,15 +73,14 @@ mod tests { "#; let test_cases = vec![ - // (format!("select * from us{}", CURSOR_POS), "users"), - // (format!("select * from em{}", CURSOR_POS), "emails"), - (format!("select * from {}", CURSOR_POS), "addresses"), + (format!("select * from u{}", CURSOR_POS), "users"), + (format!("select * from e{}", CURSOR_POS), "emails"), + (format!("select * from a{}", CURSOR_POS), "addresses"), ]; for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; let params = get_test_params(&tree, &cache, query.as_str().into()); - println!("{}, {}", ¶ms.text, ¶ms.position); let items = complete(params); assert!(!items.is_empty()); diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index fb1a73b6..5f964813 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -25,6 +25,7 @@ where if cursor_inbetween_nodes(tree, params.position) || cursor_prepared_to_write_token_after_last_node(tree, params.position) + || cursor_before_semicolon(tree, params.position) { Ok(SanitizedCompletionParams::with_adjusted_sql(params)) } else { @@ -134,7 +135,7 @@ fn cursor_prepared_to_write_token_after_last_node( cursor_pos == tree.root_node().end_byte() + 1 } -fn cursor_before_or_on_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { +fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { let mut cursor = tree.walk(); let mut leaf_node = tree.root_node(); @@ -183,7 +184,7 @@ mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_or_on_semicolon, cursor_inbetween_nodes, + cursor_before_semicolon, cursor_inbetween_nodes, cursor_prepared_to_write_token_after_last_node, }; @@ -262,23 +263,23 @@ mod tests { let mut tree = parser.parse(input.to_string(), None).unwrap(); // select * from ;| <-- it's after the statement - assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(19))); + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(19))); // select * from| ; <-- still touches the from - assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(13))); + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(13))); // not okay to be ON the semi. // select * from |; - assert!(!cursor_before_or_on_semicolon(&mut tree, TextSize::new(18))); + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(18))); // anything is fine here // select * from | ; // select * from | ; // select * from | ; // select * from |; - assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(14))); - assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(15))); - assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(16))); - assert!(cursor_before_or_on_semicolon(&mut tree, TextSize::new(17))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(14))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(15))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(16))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(17))); } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index af49893f..59e9a5c5 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -15,11 +15,8 @@ impl From<&str> for InputQuery { fn from(value: &str) -> Self { let position = value .find(CURSOR_POS) - .map(|p| p.saturating_add(1)) .expect("Insert Cursor Position into your Query."); - println!("{}", position); - InputQuery { sql: value.replace(CURSOR_POS, ""), position, @@ -76,3 +73,43 @@ pub(crate) fn get_test_params<'a>( text, } } + +#[cfg(test)] +mod tests { + use crate::test_helper::CURSOR_POS; + + use super::InputQuery; + + #[test] + fn input_query_should_extract_correct_position() { + struct TestCase { + query: String, + expected_pos: usize, + expected_sql_len: usize, + } + + let cases = vec![ + TestCase { + query: format!("select * from{}", CURSOR_POS), + expected_pos: 13, + expected_sql_len: 13, + }, + TestCase { + query: format!("{}select * from", CURSOR_POS), + expected_pos: 0, + expected_sql_len: 13, + }, + TestCase { + query: format!("select {} from", CURSOR_POS), + expected_pos: 7, + expected_sql_len: 12, + }, + ]; + + for case in cases { + let query = InputQuery::from(case.query.as_str()); + assert_eq!(query.position, case.expected_pos); + assert_eq!(query.sql.len(), case.expected_sql_len); + } + } +} From c8a60823eea427399351ff16b9640dd08d01cdbd Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 12:20:40 +0200 Subject: [PATCH 07/11] give it a spin --- .../pgt_workspace/src/features/completions.rs | 58 +++++++++++++++- crates/pgt_workspace/src/workspace.rs | 2 +- crates/pgt_workspace/src/workspace/server.rs | 67 ++----------------- 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index 8fb13313..f148680e 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -1,6 +1,9 @@ +use itertools::Itertools; use pgt_completions::CompletionItem; use pgt_fs::PgTPath; -use pgt_text_size::TextSize; +use pgt_text_size::{TextRange, TextSize}; + +use crate::workspace::{Document, Statement, StatementId}; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -24,3 +27,56 @@ impl IntoIterator for CompletionsResult { self.items.into_iter() } } + +pub(crate) fn get_statement_for_completions<'a>( + doc: &'a Document, + position: TextSize, +) -> Option<(Statement, &'a TextRange, &'a str)> { + let count = doc.statement_count(); + // no arms no cookies + if count == 0 { + return None; + } + + /* + * We allow an offset of two for the statement: + * + * select * from | <-- we want to suggest items for the next token. + * + * However, if the current statement is terminated by a semicolon, we don't apply any + * offset. + * + * select * from users; | <-- no autocompletions here. + */ + let matches_expanding_range = |stmt_id: StatementId, range: &TextRange, position: TextSize| { + let measuring_range = if doc.is_terminated_by_semicolon(stmt_id).unwrap() { + *range + } else { + range.checked_expand_end(2.into()).unwrap_or(*range) + }; + measuring_range.contains(position) + }; + + if count == 1 { + let (stmt, range, txt) = doc.iter_statements_with_text_and_range().next().unwrap(); + if matches_expanding_range(stmt.id, range, position) { + Some((stmt, range, txt)) + } else { + None + } + } else { + /* + * If we have multiple statements, we want to make sure that we do not overlap + * with the next one. + * + * select 1 |select 1; + */ + let mut stmts = doc.iter_statements_with_text_and_range().tuple_windows(); + stmts + .find(|((current_stmt, rcurrent, _), (_, rnext, _))| { + let overlaps_next = rnext.contains(position); + matches_expanding_range(current_stmt.id, rcurrent, position) && !overlaps_next + }) + .map(|t| t.0) + } +} diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 4a503d5d..d965c9d2 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -21,7 +21,7 @@ use crate::{ mod client; mod server; -pub(crate) use server::StatementId; +pub(crate) use server::document::{Document, Statement, StatementId}; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index de573ef0..7e70afd5 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -5,16 +5,13 @@ use async_helper::run_async; use change::StatementChange; use dashmap::DashMap; use db_connection::DbConnection; -pub(crate) use document::StatementId; use document::{Document, Statement}; use futures::{StreamExt, stream}; -use itertools::Itertools; use pg_query::PgQueryStore; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; use pgt_fs::{ConfigName, PgTPath}; -use pgt_text_size::{TextRange, TextSize}; use pgt_typecheck::TypecheckParams; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; @@ -29,7 +26,7 @@ use crate::{ self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory, ExecuteStatementParams, ExecuteStatementResult, }, - completions::{CompletionsResult, GetCompletionsParams}, + completions::{self, CompletionsResult, GetCompletionsParams}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, }, settings::{Settings, SettingsHandle, SettingsHandleMut}, @@ -44,7 +41,7 @@ mod analyser; mod async_helper; mod change; mod db_connection; -mod document; +pub(crate) mod document; mod migration; mod pg_query; mod schema_cache_manager; @@ -540,64 +537,12 @@ impl Workspace for WorkspaceServer { .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; - let count = doc.statement_count(); - // no arms no cookies - if count == 0 { - return Ok(CompletionsResult::default()); - } - - /* - * We allow an offset of two for the statement: - * - * select * from | <-- we want to suggest items for the next token. - * - * However, if the current statement is terminated by a semicolon, we don't apply any - * offset. - * - * select * from users; | <-- no autocompletions here. - */ - let matches_expanding_range = - |stmt_id: StatementId, range: &TextRange, position: TextSize| { - let measuring_range = if doc.is_terminated_by_semicolon(stmt_id).unwrap() { - *range - } else { - range.checked_expand_end(2.into()).unwrap_or(*range) - }; - measuring_range.contains(position) + let (statement, stmt_range, text) = + match completions::get_statement_for_completions(&doc, params.position) { + None => return Ok(CompletionsResult::default()), + Some(s) => s, }; - let maybe_statement = if count == 1 { - let (stmt, range, txt) = doc.iter_statements_with_text_and_range().next().unwrap(); - if matches_expanding_range(stmt.id, range, params.position) { - Some((stmt, range, txt)) - } else { - None - } - } else { - /* - * If we have multiple statements, we want to make sure that we do not overlap - * with the next one. - * - * select 1 |select 1; - */ - let mut stmts = doc.iter_statements_with_text_and_range().tuple_windows(); - stmts - .find(|((current_stmt, rcurrent, _), (_, rnext, _))| { - let overlaps_next = rnext.contains(params.position); - matches_expanding_range(current_stmt.id, &rcurrent, params.position) - && !overlaps_next - }) - .map(|t| t.0) - }; - - let (statement, stmt_range, text) = match maybe_statement { - Some(it) => it, - None => { - tracing::debug!("No matching statement found for completion."); - return Ok(CompletionsResult::default()); - } - }; - // `offset` is the position in the document, // but we need the position within the *statement*. let position = params.position - stmt_range.start(); From 5dff3a69794309c07af75c55d00468b4277866df Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 12:41:07 +0200 Subject: [PATCH 08/11] yeah --- .../pgt_workspace/src/features/completions.rs | 114 ++++++++++++++++++ .../src/workspace/server/document.rs | 3 +- 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index f148680e..bcbb8084 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -80,3 +80,117 @@ pub(crate) fn get_statement_for_completions<'a>( .map(|t| t.0) } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use pgt_fs::PgTPath; + use pgt_text_size::TextSize; + + use crate::workspace::Document; + + use super::get_statement_for_completions; + + static CURSOR_POSITION: &str = "€"; + + fn get_doc_and_pos(sql: &str) -> (Document, TextSize) { + let pos = sql + .find(CURSOR_POSITION) + .expect("Please add cursor position to test sql"); + + let pos: u32 = pos.try_into().unwrap(); + + ( + Document::new( + PgTPath::new("test.sql"), + sql.replace(CURSOR_POSITION, "").into(), + 5, + ), + TextSize::new(pos), + ) + } + + #[test] + fn finds_matching_statement() { + let sql = format!( + r#" + select * from users; + + update {}users set email = 'myemail@com'; + + select 1; + "#, + CURSOR_POSITION + ); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "update users set email = 'myemail@com';") + } + + #[test] + fn does_not_break_when_no_statements_exist() { + let sql = format!("{}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert_eq!(get_statement_for_completions(&doc, position), None); + } + + #[test] + fn does_not_return_overlapping_statements_if_too_close() { + let sql = format!("select * from {}select 1;", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + // make sure these are parsed as two + assert_eq!(doc.iter_statements().try_len().unwrap(), 2); + + assert_eq!(get_statement_for_completions(&doc, position), None); + } + + #[test] + fn is_fine_with_spaces() { + let sql = format!("select * from {} ;", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "select * from ;") + } + + #[test] + fn considers_offset() { + let sql = format!("select * from {}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "select * from") + } + + #[test] + fn does_not_consider_too_far_offset() { + let sql = format!("select * from {}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert_eq!(get_statement_for_completions(&doc, position), None); + } + + #[test] + fn does_not_consider_offset_if_statement_terminated_by_semi() { + let sql = format!("select * from users;{}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert_eq!(get_statement_for_completions(&doc, position), None); + } +} diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index a1bd74d4..2cd8ec5e 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -123,7 +123,8 @@ impl Document { .iter() .find(|pos| pos.0 == stmt_id) .map(|(_, range)| { - let final_char = self.content.chars().nth(range.end().into()); + let length: usize = range.end().into(); + let final_char = self.content.chars().nth(length - 1); final_char == Some(';') }) } From c4dee15f040b7401bb73be9c2101f9689aeb3b71 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 12:48:31 +0200 Subject: [PATCH 09/11] leave the logs to the beavers --- crates/pgt_completions/src/context.rs | 3 --- crates/pgt_completions/src/relevance.rs | 5 +---- postgrestools.jsonc | 2 +- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 45e287b1..a4578df8 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -85,10 +85,7 @@ impl<'a> CompletionContext<'a> { mentioned_relations: HashMap::new(), }; - tracing::warn!("gathering tree context"); ctx.gather_tree_context(); - - tracing::warn!("gathering info from ts query"); ctx.gather_info_from_ts_queries(); ctx diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index e3fa5918..9650a94d 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -57,10 +57,7 @@ impl CompletionRelevance<'_> { let name = match self.data { CompletionRelevanceData::Function(f) => f.name.as_str(), CompletionRelevanceData::Table(t) => t.name.as_str(), - CompletionRelevanceData::Column(c) => { - // - c.name.as_str() - } + CompletionRelevanceData::Column(c) => c.name.as_str(), }; if name.starts_with(content) { diff --git a/postgrestools.jsonc b/postgrestools.jsonc index 0ce2e44f..325c7861 100644 --- a/postgrestools.jsonc +++ b/postgrestools.jsonc @@ -17,7 +17,7 @@ // YOU CAN COMMENT ME OUT :) "db": { "host": "127.0.0.1", - "port": 54322, + "port": 5432, "username": "postgres", "password": "postgres", "database": "postgres", From 03a06b3cd2c655e1e16889e190bd2d03b28f8c4e Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Apr 2025 18:28:53 +0200 Subject: [PATCH 10/11] add benchmarks --- Cargo.lock | 157 ++++++++++- crates/pgt_completions/Cargo.toml | 5 + .../pgt_completions/benches/sanitization.rs | 249 ++++++++++++++++++ crates/pgt_completions/src/lib.rs | 1 + crates/pgt_completions/src/sanitization.rs | 5 + 5 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 crates/pgt_completions/benches/sanitization.rs diff --git a/Cargo.lock b/Cargo.lock index 779d82b6..e501d75b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,12 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -733,6 +739,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.3" @@ -766,6 +778,33 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -898,6 +937,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -954,6 +1029,12 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crunchy" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1513,6 +1594,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1572,6 +1663,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hermit-abi" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" + [[package]] name = "hex" version = "0.4.3" @@ -1821,6 +1918,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "is-terminal" +version = "0.4.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" +dependencies = [ + "hermit-abi 0.5.0", + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "is_ci" version = "1.2.0" @@ -2238,6 +2346,12 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "option-ext" version = "0.2.0" @@ -2452,6 +2566,7 @@ name = "pgt_completions" version = "0.0.0" dependencies = [ "async-std", + "criterion", "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", @@ -2864,6 +2979,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.8.0" @@ -3009,7 +3152,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck", - "itertools 0.10.5", + "itertools 0.14.0", "log", "multimap", "once_cell", @@ -3029,7 +3172,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.90", @@ -4079,6 +4222,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/crates/pgt_completions/Cargo.toml b/crates/pgt_completions/Cargo.toml index 559639f3..a69ee75a 100644 --- a/crates/pgt_completions/Cargo.toml +++ b/crates/pgt_completions/Cargo.toml @@ -31,6 +31,7 @@ sqlx.workspace = true tokio = { version = "1.41.1", features = ["full"] } [dev-dependencies] +criterion = "0.5.1" pgt_test_utils.workspace = true [lib] @@ -38,3 +39,7 @@ doctest = false [features] schema = ["dep:schemars"] + +[[bench]] +harness = false +name = "sanitization" diff --git a/crates/pgt_completions/benches/sanitization.rs b/crates/pgt_completions/benches/sanitization.rs new file mode 100644 index 00000000..c21538de --- /dev/null +++ b/crates/pgt_completions/benches/sanitization.rs @@ -0,0 +1,249 @@ +use criterion::{Criterion, black_box, criterion_group, criterion_main}; +use pgt_completions::{CompletionParams, benchmark_sanitization}; +use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextSize; + +static CURSOR_POS: &str = "€"; + +fn sql_and_pos(sql: &str) -> (String, usize) { + let pos = sql.find(CURSOR_POS).unwrap(); + (sql.replace(CURSOR_POS, ""), pos) +} + +fn get_tree(sql: &str) -> tree_sitter::Tree { + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + parser.parse(sql.to_string(), None).unwrap() +} + +fn to_params<'a>( + text: String, + tree: &'a tree_sitter::Tree, + pos: usize, + cache: &'a SchemaCache, +) -> CompletionParams<'a> { + let pos: u32 = pos.try_into().unwrap(); + CompletionParams { + position: TextSize::new(pos), + schema: &cache, + text, + tree: Some(tree), + } +} + +pub fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("small sql, adjusted", |b| { + let content = format!("select {} from users;", CURSOR_POS); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("mid sql, adjusted", |b| { + let content = format!( + r#"select + n.oid :: int8 as "id!", + n.nspname as name, + u.rolname as "owner!" +from + pg_namespace n, + {} +where + n.nspowner = u.oid + and ( + pg_has_role(n.nspowner, 'USAGE') + or has_schema_privilege(n.oid, 'CREATE, USAGE') + ) + and not pg_catalog.starts_with(n.nspname, 'pg_temp_') + and not pg_catalog.starts_with(n.nspname, 'pg_toast_temp_');"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("large sql, adjusted", |b| { + let content = format!( + r#"with + available_tables as ( + select + c.relname as table_name, + c.oid as table_oid, + c.relkind as class_kind, + n.nspname as schema_name + from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where + -- r: normal tables + -- v: views + -- m: materialized views + -- f: foreign tables + -- p: partitioned tables + c.relkind in ('r', 'v', 'm', 'f', 'p') + ), + available_indexes as ( + select + unnest (ix.indkey) as attnum, + ix.indisprimary as is_primary, + ix.indisunique as is_unique, + ix.indrelid as table_oid + from + {} + where + c.relkind = 'i' + ) +select + atts.attname as name, + ts.table_name, + ts.table_oid :: int8 as "table_oid!", + ts.class_kind :: char as "class_kind!", + ts.schema_name, + atts.atttypid :: int8 as "type_id!", + not atts.attnotnull as "is_nullable!", + nullif( + information_schema._pg_char_max_length (atts.atttypid, atts.atttypmod), + -1 + ) as varchar_length, + pg_get_expr (def.adbin, def.adrelid) as default_expr, + coalesce(ix.is_primary, false) as "is_primary_key!", + coalesce(ix.is_unique, false) as "is_unique!", + pg_catalog.col_description (ts.table_oid, atts.attnum) as comment +from + pg_catalog.pg_attribute atts + join available_tables ts on atts.attrelid = ts.table_oid + left join available_indexes ix on atts.attrelid = ix.table_oid + and atts.attnum = ix.attnum + left join pg_catalog.pg_attrdef def on atts.attrelid = def.adrelid + and atts.attnum = def.adnum +where + -- system columns, such as `cmax` or `tableoid`, have negative `attnum`s + atts.attnum >= 0; +"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("small sql, unadjusted", |b| { + let content = format!("select e{} from users;", CURSOR_POS); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("mid sql, unadjusted", |b| { + let content = format!( + r#"select + n.oid :: int8 as "id!", + n.nspname as name, + u.rolname as "owner!" +from + pg_namespace n, + pg_r{} +where + n.nspowner = u.oid + and ( + pg_has_role(n.nspowner, 'USAGE') + or has_schema_privilege(n.oid, 'CREATE, USAGE') + ) + and not pg_catalog.starts_with(n.nspname, 'pg_temp_') + and not pg_catalog.starts_with(n.nspname, 'pg_toast_temp_');"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("large sql, unadjusted", |b| { + let content = format!( + r#"with + available_tables as ( + select + c.relname as table_name, + c.oid as table_oid, + c.relkind as class_kind, + n.nspname as schema_name + from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where + -- r: normal tables + -- v: views + -- m: materialized views + -- f: foreign tables + -- p: partitioned tables + c.relkind in ('r', 'v', 'm', 'f', 'p') + ), + available_indexes as ( + select + unnest (ix.indkey) as attnum, + ix.indisprimary as is_primary, + ix.indisunique as is_unique, + ix.indrelid as table_oid + from + pg_catalog.pg_class c + join pg_catalog.pg_index ix on c.oid = ix.indexrelid + where + c.relkind = 'i' + ) +select + atts.attname as name, + ts.table_name, + ts.table_oid :: int8 as "table_oid!", + ts.class_kind :: char as "class_kind!", + ts.schema_name, + atts.atttypid :: int8 as "type_id!", + not atts.attnotnull as "is_nullable!", + nullif( + information_schema._pg_char_max_length (atts.atttypid, atts.atttypmod), + -1 + ) as varchar_length, + pg_get_expr (def.adbin, def.adrelid) as default_expr, + coalesce(ix.is_primary, false) as "is_primary_key!", + coalesce(ix.is_unique, false) as "is_unique!", + pg_catalog.col_description (ts.table_oid, atts.attnum) as comment +from + pg_catalog.pg_attribute atts + join available_tables ts on atts.attrelid = ts.table_oid + left join available_indexes ix on atts.attrelid = ix.table_oid + and atts.attnum = ix.attnum + left join pg_catalog.pg_attrdef def on atts.attrelid = def.adrelid + and atts.attnum = def.adnum +where + -- system columns, such as `cmax` or `tableoid`, have negative `attnum`s + atts.attnum >= 0 +order by + sch{} "#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/crates/pgt_completions/src/lib.rs b/crates/pgt_completions/src/lib.rs index c37c4d0f..f8ca1a55 100644 --- a/crates/pgt_completions/src/lib.rs +++ b/crates/pgt_completions/src/lib.rs @@ -11,3 +11,4 @@ mod test_helper; pub use complete::*; pub use item::*; +pub use sanitization::*; diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 5f964813..1171da55 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -11,6 +11,11 @@ pub(crate) struct SanitizedCompletionParams<'a> { pub tree: Cow<'a, tree_sitter::Tree>, } +pub fn benchmark_sanitization(params: CompletionParams) -> String { + let params: SanitizedCompletionParams = params.try_into().unwrap(); + params.text +} + impl<'larger, 'smaller> TryFrom> for SanitizedCompletionParams<'smaller> where 'larger: 'smaller, From d01c79d0d0cacb9a0895c24814022549b9ab01e9 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 12 Apr 2025 14:14:44 +0200 Subject: [PATCH 11/11] cant fail --- crates/pgt_completions/src/complete.rs | 8 +------- crates/pgt_completions/src/sanitization.rs | 10 ++++------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 1a6a7bdd..ec1232a5 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -23,13 +23,7 @@ pub struct CompletionParams<'a> { position = params.position.to_string() ))] pub fn complete(params: CompletionParams) -> Vec { - let sanitized_params = match SanitizedCompletionParams::try_from(params) { - Ok(p) => p, - Err(err) => { - tracing::warn!("Not possible to get completions: {}", err); - return vec![]; - } - }; + let sanitized_params = SanitizedCompletionParams::from(params); let ctx = CompletionContext::new(&sanitized_params); diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 7859b41a..5ad8ba0e 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -16,20 +16,18 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String { params.text } -impl<'larger, 'smaller> TryFrom> for SanitizedCompletionParams<'smaller> +impl<'larger, 'smaller> From> for SanitizedCompletionParams<'smaller> where 'larger: 'smaller, { - type Error = String; - - fn try_from(params: CompletionParams<'larger>) -> Result { + fn from(params: CompletionParams<'larger>) -> Self { if cursor_inbetween_nodes(params.tree, params.position) || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) || cursor_before_semicolon(params.tree, params.position) { - Ok(SanitizedCompletionParams::with_adjusted_sql(params)) + SanitizedCompletionParams::with_adjusted_sql(params) } else { - Ok(SanitizedCompletionParams::unadjusted(params)) + SanitizedCompletionParams::unadjusted(params) } } }