Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve auto completions #310

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/pgt_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pgt_treesitter_queries.workspace = true
schemars = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tracing = { workspace = true }
tree-sitter.workspace = true
tree_sitter_sql.workspace = true

Expand Down
4 changes: 4 additions & 0 deletions crates/pgt_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ pub struct CompletionParams<'a> {
pub tree: Option<&'a tree_sitter::Tree>,
}

#[tracing::instrument(level = "debug", skip_all, fields(
text = params.text,
position = params.position.to_string()
))]
pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
let ctx = CompletionContext::new(&params);

Expand Down
63 changes: 51 additions & 12 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,29 @@ impl TryFrom<String> for ClauseType {
}

pub(crate) struct CompletionContext<'a> {
pub ts_node: Option<tree_sitter::Node<'a>>,
pub node_under_cursor: Option<tree_sitter::Node<'a>>,
pub previous_node: Option<tree_sitter::Node<'a>>,

pub tree: Option<&'a tree_sitter::Tree>,
pub text: &'a str,
pub schema_cache: &'a SchemaCache,
pub position: usize,

/// If the cursor of the user is offset to the right of the statement,
/// we'll have to move it back to the last node, otherwise, tree-sitter will break.
/// However, knowing that the user is typing on the "next" node lets us prioritize different completion results.
/// We consider an offset of up to two characters as valid.
///
/// Example:
///
/// ```
/// select * from {}
/// ```
///
/// We'll adjust the cursor position so it lies on the "from" token – but we're looking
/// for table completions.
pub cursor_offset_from_end: bool,

pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
Expand All @@ -70,7 +87,9 @@ impl<'a> CompletionContext<'a> {
text: &params.text,
schema_cache: params.schema,
position: usize::from(params.position),
ts_node: None,
cursor_offset_from_end: false,
previous_node: None,
node_under_cursor: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
Expand Down Expand Up @@ -145,30 +164,34 @@ impl<'a> CompletionContext<'a> {
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node = cursor.node();
let position_cache = self.position.clone();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

let cursor_offset = position_cache - self.position;
self.cursor_offset_from_end = cursor_offset > 0 && cursor_offset <= 2;

self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node: tree_sitter::Node<'a>,
parent_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
if current_node.kind() == parent_node.kind() {
self.node_under_cursor = Some(current_node);
return;
}

match previous_node.kind() {
match parent_node.kind() {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.range());
self.wrapping_statement_range = Some(parent_node.range());
}
"invocation" => self.is_invocation = true,

Expand Down Expand Up @@ -200,7 +223,23 @@ impl<'a> CompletionContext<'a> {

// We have arrived at the leaf node
if current_node.child_count() == 0 {
self.ts_node = Some(current_node);
if self.cursor_offset_from_end {
self.node_under_cursor = None;
self.previous_node = Some(current_node);
} else {
// for the previous node, either select the previous sibling,
// or collect the parent's previous sibling's last child.
let previous = match current_node.prev_sibling() {
Some(n) => Some(n),
None => {
let sib_of_parent = parent_node.prev_sibling();
sib_of_parent.and_then(|p| p.children(&mut cursor).last())
}
};
self.node_under_cursor = Some(current_node);
self.previous_node = previous;
}

return;
}

Expand Down Expand Up @@ -359,7 +398,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("select"));

Expand Down Expand Up @@ -387,7 +426,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("from"));
assert_eq!(
Expand All @@ -413,7 +452,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some(""));
assert_eq!(ctx.wrapping_clause_type, None);
Expand All @@ -438,7 +477,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("fro"));
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
Expand Down
82 changes: 78 additions & 4 deletions crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,86 @@ mod tests {
let params = get_test_params(&tree, &cache, case.get_input_query());
let mut items = complete(params);

let _ = items.split_off(3);
let _ = items.split_off(6);

items.sort_by(|a, b| a.label.cmp(&b.label));
#[derive(Eq, PartialEq, Debug)]
struct LabelAndDesc {
label: String,
desc: String,
}

let labels: Vec<LabelAndDesc> = items
.into_iter()
.map(|c| LabelAndDesc {
label: c.label,
desc: c.description,
})
.collect();

let expected = vec![
("name", "Table: public.users"),
("narrator", "Table: public.audio_books"),
("narrator_id", "Table: private.audio_books"),
("name", "Schema: pg_catalog"),
("nameconcatoid", "Schema: pg_catalog"),
("nameeq", "Schema: pg_catalog"),
]
.into_iter()
.map(|(label, schema)| LabelAndDesc {
label: label.into(),
desc: schema.into(),
})
.collect::<Vec<LabelAndDesc>>();

assert_eq!(labels, expected);
}

#[tokio::test]
async fn suggests_relevant_columns_without_letters() {
let setup = r#"
create table users (
id serial primary key,
name text,
address text,
email text
);
"#;

let test_case = TestCase {
message: "suggests user created tables first",
query: format!(r#"select {} from users"#, CURSOR_POS),
label: "",
description: "",
};

let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await;
let params = get_test_params(&tree, &cache, test_case.get_input_query());
let results = complete(params);

let labels: Vec<String> = items.into_iter().map(|c| c.label).collect();
let (first_four, _rest) = results.split_at(4);

let has_column_in_first_four = |col: &'static str| {
first_four
.iter()
.find(|compl_item| compl_item.label.as_str() == col)
.is_some()
};

assert_eq!(labels, vec!["name", "narrator", "narrator_id"]);
assert!(
has_column_in_first_four("id"),
"`id` not present in first four completion items."
);
assert!(
has_column_in_first_four("name"),
"`name` not present in first four completion items."
);
assert!(
has_column_in_first_four("address"),
"`address` not present in first four completion items."
);
assert!(
has_column_in_first_four("email"),
"`email` not present in first four completion items."
);
}
}
4 changes: 2 additions & 2 deletions crates/pgt_completions/src/providers/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ mod tests {
"#;

let test_cases = vec![
(format!("select * from us{}", CURSOR_POS), "users"),
(format!("select * from em{}", CURSOR_POS), "emails"),
// (format!("select * from us{}", CURSOR_POS), "users"),
// (format!("select * from em{}", CURSOR_POS), "emails"),
(format!("select * from {}", CURSOR_POS), "addresses"),
];

Expand Down
27 changes: 15 additions & 12 deletions crates/pgt_completions/src/relevance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ impl CompletionRelevance<'_> {
self.check_is_user_defined();
self.check_matches_schema(ctx);
self.check_matches_query_input(ctx);
self.check_if_catalog(ctx);
self.check_is_invocation(ctx);
self.check_matching_clause_type(ctx);
self.check_relations_in_stmt(ctx);
Expand All @@ -42,7 +41,10 @@ impl CompletionRelevance<'_> {
}

fn check_matches_query_input(&mut self, ctx: &CompletionContext) {
let node = ctx.ts_node.unwrap();
let node = match ctx.node_under_cursor {
Some(node) => node,
None => return,
};

let content = match ctx.get_ts_node_content(node) {
Some(c) => c,
Expand All @@ -52,7 +54,10 @@ impl CompletionRelevance<'_> {
let name = match self.data {
CompletionRelevanceData::Function(f) => f.name.as_str(),
CompletionRelevanceData::Table(t) => t.name.as_str(),
CompletionRelevanceData::Column(c) => c.name.as_str(),
CompletionRelevanceData::Column(c) => {
//
c.name.as_str()
}
};

if name.starts_with(content) {
Expand All @@ -61,7 +66,7 @@ impl CompletionRelevance<'_> {
.try_into()
.expect("The length of the input exceeds i32 capacity");

self.score += len * 5;
self.score += len * 10;
};
}

Expand Down Expand Up @@ -135,14 +140,6 @@ impl CompletionRelevance<'_> {
}
}

fn check_if_catalog(&mut self, ctx: &CompletionContext) {
if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") {
return;
}

self.score -= 5; // unlikely that the user wants schema data
}

fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) {
match self.data {
CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return,
Expand Down Expand Up @@ -182,5 +179,11 @@ impl CompletionRelevance<'_> {
if system_schemas.contains(&schema.as_str()) {
self.score -= 10;
}

// "public" is the default postgres schema where users
// create objects. Prefer it by a slight bit.
if schema.as_str() == "public" {
self.score += 2;
}
}
}
20 changes: 13 additions & 7 deletions crates/pgt_workspace/src/workspace/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use pgt_analyse::{AnalyserOptions, AnalysisFilter};
use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext};
use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic};
use pgt_fs::{ConfigName, PgTPath};
use pgt_text_size::{TextRange, TextSize};
use pgt_typecheck::TypecheckParams;
use schema_cache_manager::SchemaCacheManager;
use sqlx::Executor;
Expand Down Expand Up @@ -535,13 +536,18 @@ impl Workspace for WorkspaceServer {
.get(&params.path)
.ok_or(WorkspaceError::not_found())?;

let (statement, stmt_range, text) = match doc
.iter_statements_with_text_and_range()
.find(|(_, r, _)| r.contains(params.position))
{
Some(s) => s,
None => return Ok(CompletionsResult::default()),
};
let (statement, stmt_range, text) =
match doc.iter_statements_with_text_and_range().find(|(_, r, _)| {
let expanded_range = TextRange::new(
r.start(),
r.end().checked_add(TextSize::new(2)).unwrap_or(r.end()),
);

expanded_range.contains(params.position)
}) {
Some(s) => s,
None => return Ok(CompletionsResult::default()),
};

// `offset` is the position in the document,
// but we need the position within the *statement*.
Expand Down
2 changes: 1 addition & 1 deletion postgrestools.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// YOU CAN COMMENT ME OUT :)
"db": {
"host": "127.0.0.1",
"port": 5432,
"port": 54322,
"username": "postgres",
"password": "postgres",
"database": "postgres",
Expand Down
Loading