Skip to content

Commit 9a75604

Browse files
add test
1 parent f594f87 commit 9a75604

File tree

6 files changed

+159
-55
lines changed

6 files changed

+159
-55
lines changed

crates/pgt_completions/src/context.rs

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pgt_treesitter_queries::{
88

99
use crate::sanitization::SanitizedCompletionParams;
1010

11-
#[derive(Debug, PartialEq, Eq)]
11+
#[derive(Debug, PartialEq, Eq, Hash)]
1212
pub enum WrappingClause<'a> {
1313
Select,
1414
Where,
@@ -115,7 +115,7 @@ pub(crate) struct CompletionContext<'a> {
115115

116116
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
117117
pub mentioned_table_aliases: HashMap<String, String>,
118-
pub mentioned_columns: HashSet<MentionedColumn>,
118+
pub mentioned_columns: HashMap<Option<WrappingClause<'a>>, HashSet<MentionedColumn>>,
119119
}
120120

121121
impl<'a> CompletionContext<'a> {
@@ -133,7 +133,7 @@ impl<'a> CompletionContext<'a> {
133133
is_invocation: false,
134134
mentioned_relations: HashMap::new(),
135135
mentioned_table_aliases: HashMap::new(),
136-
mentioned_columns: HashSet::new(),
136+
mentioned_columns: HashMap::new(),
137137
is_in_error_node: false,
138138
};
139139

@@ -151,38 +151,45 @@ impl<'a> CompletionContext<'a> {
151151

152152
executor.add_query_results::<queries::RelationMatch>();
153153
executor.add_query_results::<queries::TableAliasMatch>();
154-
executor.add_query_results::<queries::ColumnMatch>();
154+
executor.add_query_results::<queries::SelectColumnMatch>();
155155

156156
for relation_match in executor.get_iter(stmt_range) {
157157
match relation_match {
158158
QueryResult::Relation(r) => {
159159
let schema_name = r.get_schema(sql);
160160
let table_name = r.get_table(sql);
161161

162-
let current = self.mentioned_relations.get_mut(&schema_name);
163-
164-
match current {
165-
Some(c) => {
166-
c.insert(table_name);
167-
}
168-
None => {
169-
let mut new = HashSet::new();
170-
new.insert(table_name);
171-
self.mentioned_relations.insert(schema_name, new);
172-
}
173-
};
162+
if let Some(c) = self.mentioned_relations.get_mut(&schema_name) {
163+
c.insert(table_name);
164+
} else {
165+
let mut new = HashSet::new();
166+
new.insert(table_name);
167+
self.mentioned_relations.insert(schema_name, new);
168+
}
174169
}
175170
QueryResult::TableAliases(table_alias_match) => {
176171
self.mentioned_table_aliases.insert(
177172
table_alias_match.get_alias(sql),
178173
table_alias_match.get_table(sql),
179174
);
180175
}
181-
QueryResult::Column(c) => {
182-
self.mentioned_columns.insert(MentionedColumn {
176+
QueryResult::SelectClauseColumns(c) => {
177+
let mentioned = MentionedColumn {
183178
column: c.get_column(sql),
184179
alias: c.get_alias(sql),
185-
});
180+
};
181+
182+
if let Some(cols) = self
183+
.mentioned_columns
184+
.get_mut(&Some(WrappingClause::Select))
185+
{
186+
cols.insert(mentioned);
187+
} else {
188+
let mut new = HashSet::new();
189+
new.insert(mentioned);
190+
self.mentioned_columns
191+
.insert(Some(WrappingClause::Select), new);
192+
}
186193
}
187194
};
188195
}

crates/pgt_completions/src/providers/columns.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,4 +484,93 @@ mod tests {
484484
)
485485
.await;
486486
}
487+
488+
#[tokio::test]
489+
async fn prefers_not_mentioned_columns() {
490+
let setup = r#"
491+
create schema auth;
492+
493+
create table public.one (
494+
id serial primary key,
495+
a text,
496+
b text,
497+
z text
498+
);
499+
500+
create table public.two (
501+
id serial primary key,
502+
c text,
503+
d text,
504+
e text
505+
);
506+
"#;
507+
508+
assert_complete_results(
509+
format!(
510+
"select {} from public.one o join public.two on o.id = t.id;",
511+
CURSOR_POS
512+
)
513+
.as_str(),
514+
vec![
515+
CompletionAssertion::Label("a".to_string()),
516+
CompletionAssertion::Label("b".to_string()),
517+
CompletionAssertion::Label("c".to_string()),
518+
CompletionAssertion::Label("d".to_string()),
519+
CompletionAssertion::Label("e".to_string()),
520+
],
521+
setup,
522+
)
523+
.await;
524+
525+
// "a" is already mentioned, so it jumps down
526+
assert_complete_results(
527+
format!(
528+
"select a, {} from public.one o join public.two on o.id = t.id;",
529+
CURSOR_POS
530+
)
531+
.as_str(),
532+
vec![
533+
CompletionAssertion::Label("b".to_string()),
534+
CompletionAssertion::Label("c".to_string()),
535+
CompletionAssertion::Label("d".to_string()),
536+
CompletionAssertion::Label("e".to_string()),
537+
CompletionAssertion::Label("id".to_string()),
538+
CompletionAssertion::Label("z".to_string()),
539+
CompletionAssertion::Label("a".to_string()),
540+
],
541+
setup,
542+
)
543+
.await;
544+
545+
// "id" of table one is mentioned, but table two isn't –
546+
// its priority stays up
547+
assert_complete_results(
548+
format!(
549+
"select o.id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;",
550+
CURSOR_POS
551+
)
552+
.as_str(),
553+
vec![
554+
CompletionAssertion::LabelAndDesc(
555+
"id".to_string(),
556+
"Table: public.two".to_string(),
557+
),
558+
CompletionAssertion::Label("z".to_string()),
559+
],
560+
setup,
561+
)
562+
.await;
563+
564+
// "id" is ambiguous, so both "id" columns are lowered in priority
565+
assert_complete_results(
566+
format!(
567+
"select id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;",
568+
CURSOR_POS
569+
)
570+
.as_str(),
571+
vec![CompletionAssertion::Label("z".to_string())],
572+
setup,
573+
)
574+
.await;
575+
}
487576
}

crates/pgt_completions/src/relevance/scoring.rs

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,6 @@ impl CompletionScore<'_> {
238238
}
239239

240240
fn check_columns_in_stmt(&mut self, ctx: &CompletionContext) {
241-
// we only want to consider mentioned columns in a select statement.
242-
if ctx
243-
.wrapping_clause_type
244-
.as_ref()
245-
.is_none_or(|ct| ct != &WrappingClause::Select)
246-
{
247-
return;
248-
}
249-
250241
match self.data {
251242
CompletionRelevanceData::Column(column) => {
252243
/*
@@ -266,14 +257,16 @@ impl CompletionScore<'_> {
266257
*/
267258
if ctx
268259
.mentioned_columns
269-
.iter()
270-
.any(|mentioned| match mentioned.alias.as_ref() {
271-
Some(als) => {
272-
let aliased_table = ctx.mentioned_table_aliases.get(als.as_str());
273-
column.name == mentioned.column
274-
&& aliased_table.is_none_or(|t| t == &column.table_name)
275-
}
276-
None => mentioned.column == column.name,
260+
.get(&ctx.wrapping_clause_type)
261+
.is_some_and(|set| {
262+
set.iter().any(|mentioned| match mentioned.alias.as_ref() {
263+
Some(als) => {
264+
let aliased_table = ctx.mentioned_table_aliases.get(als.as_str());
265+
column.name == mentioned.column
266+
&& aliased_table.is_none_or(|t| t == &column.table_name)
267+
}
268+
None => mentioned.column == column.name,
269+
})
277270
})
278271
{
279272
self.score -= 10;

crates/pgt_completions/src/test_helper.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ mod tests {
146146
pub(crate) enum CompletionAssertion {
147147
Label(String),
148148
LabelAndKind(String, CompletionItemKind),
149+
LabelAndDesc(String, String),
149150
LabelNotExists(String),
150151
KindNotExists(CompletionItemKind),
151152
}
@@ -186,6 +187,18 @@ impl CompletionAssertion {
186187
kind
187188
);
188189
}
190+
CompletionAssertion::LabelAndDesc(label, desc) => {
191+
assert_eq!(
192+
&item.label, label,
193+
"Expected label to be {}, but got {}",
194+
label, &item.label
195+
);
196+
assert_eq!(
197+
&item.description, desc,
198+
"Expected desc to be {}, but got {}",
199+
desc, &item.description
200+
);
201+
}
189202
}
190203
}
191204
}
@@ -202,7 +215,9 @@ pub(crate) async fn assert_complete_results(
202215
let (not_existing, existing): (Vec<CompletionAssertion>, Vec<CompletionAssertion>) =
203216
assertions.into_iter().partition(|a| match a {
204217
CompletionAssertion::LabelNotExists(_) | CompletionAssertion::KindNotExists(_) => true,
205-
CompletionAssertion::Label(_) | CompletionAssertion::LabelAndKind(_, _) => false,
218+
CompletionAssertion::Label(_)
219+
| CompletionAssertion::LabelAndKind(_, _)
220+
| CompletionAssertion::LabelAndDesc(_, _) => false,
206221
});
207222

208223
assert!(

crates/pgt_treesitter_queries/src/queries/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
mod columns;
21
mod relations;
2+
mod select_columns;
33
mod table_aliases;
44

5-
pub use columns::*;
65
pub use relations::*;
6+
pub use select_columns::*;
77
pub use table_aliases::*;
88

99
#[derive(Debug)]
1010
pub enum QueryResult<'a> {
1111
Relation(RelationMatch<'a>),
1212
TableAliases(TableAliasMatch<'a>),
13-
Column(ColumnMatch<'a>),
13+
SelectClauseColumns(SelectColumnMatch<'a>),
1414
}
1515

1616
impl QueryResult<'_> {
@@ -31,7 +31,7 @@ impl QueryResult<'_> {
3131
let end = m.alias.end_position();
3232
start >= range.start_point && end <= range.end_point
3333
}
34-
Self::Column(cm) => {
34+
Self::SelectClauseColumns(cm) => {
3535
let start = match cm.alias {
3636
Some(n) => n.start_position(),
3737
None => cm.column.start_position(),

crates/pgt_treesitter_queries/src/queries/columns.rs renamed to crates/pgt_treesitter_queries/src/queries/select_columns.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
2121
});
2222

2323
#[derive(Debug)]
24-
pub struct ColumnMatch<'a> {
24+
pub struct SelectColumnMatch<'a> {
2525
pub(crate) alias: Option<tree_sitter::Node<'a>>,
2626
pub(crate) column: tree_sitter::Node<'a>,
2727
}
2828

29-
impl ColumnMatch<'_> {
29+
impl SelectColumnMatch<'_> {
3030
pub fn get_alias(&self, sql: &str) -> Option<String> {
3131
let str = self
3232
.alias
@@ -45,24 +45,24 @@ impl ColumnMatch<'_> {
4545
}
4646
}
4747

48-
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ColumnMatch<'a> {
48+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a SelectColumnMatch<'a> {
4949
type Error = String;
5050

5151
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
5252
match q {
53-
QueryResult::Column(c) => Ok(c),
53+
QueryResult::SelectClauseColumns(c) => Ok(c),
5454

5555
#[allow(unreachable_patterns)]
5656
_ => Err("Invalid QueryResult type".into()),
5757
}
5858
}
5959
}
6060

61-
impl<'a> QueryTryFrom<'a> for ColumnMatch<'a> {
62-
type Ref = &'a ColumnMatch<'a>;
61+
impl<'a> QueryTryFrom<'a> for SelectColumnMatch<'a> {
62+
type Ref = &'a SelectColumnMatch<'a>;
6363
}
6464

65-
impl<'a> Query<'a> for ColumnMatch<'a> {
65+
impl<'a> Query<'a> for SelectColumnMatch<'a> {
6666
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
6767
let mut cursor = tree_sitter::QueryCursor::new();
6868

@@ -73,7 +73,7 @@ impl<'a> Query<'a> for ColumnMatch<'a> {
7373
for m in matches {
7474
if m.captures.len() == 1 {
7575
let capture = m.captures[0].node;
76-
to_return.push(QueryResult::Column(ColumnMatch {
76+
to_return.push(QueryResult::SelectClauseColumns(SelectColumnMatch {
7777
alias: None,
7878
column: capture,
7979
}));
@@ -83,7 +83,7 @@ impl<'a> Query<'a> for ColumnMatch<'a> {
8383
let alias = m.captures[0].node;
8484
let column = m.captures[1].node;
8585

86-
to_return.push(QueryResult::Column(ColumnMatch {
86+
to_return.push(QueryResult::SelectClauseColumns(SelectColumnMatch {
8787
alias: Some(alias),
8888
column,
8989
}));
@@ -98,7 +98,7 @@ impl<'a> Query<'a> for ColumnMatch<'a> {
9898
mod tests {
9999
use crate::TreeSitterQueriesExecutor;
100100

101-
use super::ColumnMatch;
101+
use super::SelectColumnMatch;
102102

103103
#[test]
104104
fn finds_all_columns() {
@@ -111,9 +111,9 @@ mod tests {
111111

112112
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
113113

114-
executor.add_query_results::<ColumnMatch>();
114+
executor.add_query_results::<SelectColumnMatch>();
115115

116-
let results: Vec<&ColumnMatch> = executor
116+
let results: Vec<&SelectColumnMatch> = executor
117117
.get_iter(None)
118118
.filter_map(|q| q.try_into().ok())
119119
.collect();
@@ -150,9 +150,9 @@ from
150150

151151
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
152152

153-
executor.add_query_results::<ColumnMatch>();
153+
executor.add_query_results::<SelectColumnMatch>();
154154

155-
let results: Vec<&ColumnMatch> = executor
155+
let results: Vec<&SelectColumnMatch> = executor
156156
.get_iter(None)
157157
.filter_map(|q| q.try_into().ok())
158158
.collect();

0 commit comments

Comments
 (0)