Skip to content

Commit 9415a5d

Browse files
feat: clause_type enum
1 parent 458e433 commit 9415a5d

File tree

6 files changed

+151
-25
lines changed

6 files changed

+151
-25
lines changed

crates/pg_completions/src/complete.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use text_size::TextSize;
22

33
use crate::{
4-
builder::CompletionBuilder, context::CompletionContext, item::CompletionItem,
5-
providers::complete_tables,
4+
builder::CompletionBuilder,
5+
context::CompletionContext,
6+
item::CompletionItem,
7+
providers::{complete_functions, complete_tables},
68
};
79

810
pub const LIMIT: usize = 50;
@@ -34,6 +36,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult {
3436
let mut builder = CompletionBuilder::new();
3537

3638
complete_tables(&ctx, &mut builder);
39+
complete_functions(&ctx, &mut builder);
3740

3841
builder.finish()
3942
}

crates/pg_completions/src/context.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,30 @@ use pg_schema_cache::SchemaCache;
22

33
use crate::CompletionParams;
44

5+
#[derive(Debug, PartialEq, Eq)]
6+
pub enum ClauseType {
7+
Select,
8+
Where,
9+
From,
10+
}
11+
12+
impl From<&str> for ClauseType {
13+
fn from(value: &str) -> Self {
14+
match value {
15+
"select" => Self::Select,
16+
"where" => Self::Where,
17+
"from" => Self::From,
18+
_ => panic!("Unimplemented ClauseType: {}", value),
19+
}
20+
}
21+
}
22+
23+
impl From<String> for ClauseType {
24+
fn from(value: String) -> Self {
25+
ClauseType::from(value.as_str())
26+
}
27+
}
28+
529
pub(crate) struct CompletionContext<'a> {
630
pub ts_node: Option<tree_sitter::Node<'a>>,
731
pub tree: Option<&'a tree_sitter::Tree>,
@@ -10,7 +34,7 @@ pub(crate) struct CompletionContext<'a> {
1034
pub position: usize,
1135

1236
pub schema_name: Option<String>,
13-
pub wrapping_clause_type: Option<String>,
37+
pub wrapping_clause_type: Option<ClauseType>,
1438
pub is_invocation: bool,
1539
}
1640

@@ -65,7 +89,7 @@ impl<'a> CompletionContext<'a> {
6589
let current_node_kind = current_node.kind();
6690

6791
match previous_node_kind {
68-
"statement" => self.wrapping_clause_type = Some(current_node_kind.to_string()),
92+
"statement" => self.wrapping_clause_type = Some(current_node_kind.into()),
6993
"invocation" => self.is_invocation = true,
7094

7195
_ => {}
@@ -84,7 +108,7 @@ impl<'a> CompletionContext<'a> {
84108

85109
// in Treesitter, the Where clause is nested inside other clauses
86110
"where" => {
87-
self.wrapping_clause_type = Some("where".to_string());
111+
self.wrapping_clause_type = Some("where".into());
88112
}
89113

90114
_ => {}
@@ -156,7 +180,7 @@ mod tests {
156180

157181
let ctx = CompletionContext::new(&params);
158182

159-
assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.to_string()));
183+
assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.into()));
160184
}
161185
}
162186

crates/pg_completions/src/item.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#[derive(Debug)]
1+
#[derive(Debug, PartialEq, Eq)]
22
pub enum CompletionItemKind {
33
Table,
44
Function,

crates/pg_completions/src/providers/functions.rs

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use pg_schema_cache::Function;
2-
31
use crate::{
42
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
53
CompletionItem, CompletionItemKind,
@@ -27,10 +25,9 @@ pub fn complete_functions(ctx: &CompletionContext, builder: &mut CompletionBuild
2725
#[cfg(test)]
2826
mod tests {
2927
use crate::{
30-
context::CompletionContext,
31-
providers::complete_functions,
28+
complete,
3229
test_helper::{get_test_deps, get_test_params, CURSOR_POS},
33-
CompletionItem,
30+
CompletionItem, CompletionItemKind,
3431
};
3532

3633
#[tokio::test]
@@ -49,19 +46,83 @@ mod tests {
4946

5047
let query = format!("select coo{}", CURSOR_POS);
5148

52-
let (tree, cache, mut builder) = get_test_deps(setup, &query).await;
49+
let (tree, cache) = get_test_deps(setup, &query).await;
5350
let params = get_test_params(&tree, &cache, &query);
54-
let ctx = CompletionContext::new(&params);
51+
let results = complete(params);
5552

56-
complete_functions(&ctx, &mut builder);
53+
let CompletionItem { label, .. } = results
54+
.into_iter()
55+
.next()
56+
.expect("Should return at least one completion item");
5757

58-
let results = builder.finish();
58+
assert_eq!(label, "cool");
59+
}
5960

60-
let CompletionItem { label, .. } = results
61+
#[tokio::test]
62+
async fn prefers_fn_if_invocation() {
63+
let setup = r#"
64+
create table coos (
65+
id serial primary key,
66+
name text
67+
);
68+
69+
create or replace function cool()
70+
returns trigger
71+
language plpgsql
72+
security invoker
73+
as $$
74+
begin
75+
raise exception 'dont matter';
76+
end;
77+
$$;
78+
"#;
79+
80+
let query = format!(r#"select * from coo{}()"#, CURSOR_POS);
81+
82+
let (tree, cache) = get_test_deps(setup, &query).await;
83+
let params = get_test_params(&tree, &cache, &query);
84+
let results = complete(params);
85+
86+
let CompletionItem { label, kind, .. } = results
87+
.into_iter()
88+
.next()
89+
.expect("Should return at least one completion item");
90+
91+
assert_eq!(label, "cool");
92+
assert_eq!(kind, CompletionItemKind::Function);
93+
}
94+
95+
#[tokio::test]
96+
async fn prefers_fn_in_select_clause() {
97+
let setup = r#"
98+
create table coos (
99+
id serial primary key,
100+
name text
101+
);
102+
103+
create or replace function cool()
104+
returns trigger
105+
language plpgsql
106+
security invoker
107+
as $$
108+
begin
109+
raise exception 'dont matter';
110+
end;
111+
$$;
112+
"#;
113+
114+
let query = format!(r#"select coo{}"#, CURSOR_POS);
115+
116+
let (tree, cache) = get_test_deps(setup, &query).await;
117+
let params = get_test_params(&tree, &cache, &query);
118+
let results = complete(params);
119+
120+
let CompletionItem { label, kind, .. } = results
61121
.into_iter()
62122
.next()
63123
.expect("Should return at least one completion item");
64124

65125
assert_eq!(label, "cool");
126+
assert_eq!(kind, CompletionItemKind::Function);
66127
}
67128
}

crates/pg_completions/src/providers/tables.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,46 @@ pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder)
2323
builder.add_item(item);
2424
}
2525
}
26+
27+
mod tests {
28+
use crate::{
29+
complete,
30+
test_helper::{get_test_deps, get_test_params, CURSOR_POS},
31+
CompletionItem, CompletionItemKind,
32+
};
33+
34+
#[tokio::test]
35+
async fn prefers_table_in_from_clause() {
36+
let setup = r#"
37+
create table coos (
38+
id serial primary key,
39+
name text
40+
);
41+
42+
create or replace function cool()
43+
returns trigger
44+
language plpgsql
45+
security invoker
46+
as $$
47+
begin
48+
raise exception 'dont matter';
49+
end;
50+
$$;
51+
"#;
52+
53+
let query = format!(r#"select * from coo{}"#, CURSOR_POS);
54+
55+
let (tree, cache) = get_test_deps(setup, &query).await;
56+
let params = get_test_params(&tree, &cache, &query);
57+
58+
let results = complete(params);
59+
60+
let CompletionItem { label, kind, .. } = results
61+
.into_iter()
62+
.next()
63+
.expect("Should return at least one completion item");
64+
65+
assert_eq!(label, "coos");
66+
assert_eq!(kind, CompletionItemKind::Table);
67+
}
68+
}

crates/pg_completions/src/test_helper.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,14 @@ use pg_schema_cache::SchemaCache;
22
use pg_test_utils::test_database::get_new_test_db;
33
use sqlx::Executor;
44

5-
use crate::{builder::CompletionBuilder, CompletionParams};
5+
use crate::CompletionParams;
66

77
pub static CURSOR_POS: &str = "XXX";
88

99
pub(crate) async fn get_test_deps(
1010
setup: &str,
1111
input: &str,
12-
) -> (
13-
tree_sitter::Tree,
14-
pg_schema_cache::SchemaCache,
15-
CompletionBuilder,
16-
) {
12+
) -> (tree_sitter::Tree, pg_schema_cache::SchemaCache) {
1713
let test_db = get_new_test_db().await;
1814

1915
test_db
@@ -29,9 +25,8 @@ pub(crate) async fn get_test_deps(
2925
.expect("Error loading sql language");
3026

3127
let tree = parser.parse(input, None).unwrap();
32-
let builder = CompletionBuilder::new();
3328

34-
(tree, schema_cache, builder)
29+
(tree, schema_cache)
3530
}
3631

3732
pub(crate) fn get_test_params<'a>(

0 commit comments

Comments
 (0)