Skip to content

Commit 4f904b2

Browse files
Merge #3829
3829: Adds to SSR match for semantically equivalent call and method call r=matklad a=mikhail-m1 #3186 maybe I've missed some corner cases, but it works in general Co-authored-by: Mikhail Modin <[email protected]>
2 parents a93a04f + 35a2cd0 commit 4f904b2

File tree

1 file changed

+110
-12
lines changed

1 file changed

+110
-12
lines changed

crates/ra_ide/src/ssr.rs

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
55
use ra_ide_db::symbol_index::SymbolsDatabase;
66
use ra_ide_db::RootDatabase;
77
use ra_syntax::ast::make::try_expr_from_text;
8-
use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit};
9-
use ra_syntax::{AstNode, SyntaxElement, SyntaxNode};
8+
use ra_syntax::ast::{
9+
ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit,
10+
};
11+
use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode};
1012
use ra_text_edit::{TextEdit, TextEditBuilder};
1113
use rustc_hash::FxHashMap;
1214
use std::collections::HashMap;
13-
use std::str::FromStr;
15+
use std::{iter::once, str::FromStr};
1416

1517
#[derive(Debug, PartialEq)]
1618
pub struct SsrError(String);
@@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
219221
)
220222
}
221223

224+
fn check_call_and_method_call(
225+
pattern: CallExpr,
226+
code: MethodCallExpr,
227+
placeholders: &[Var],
228+
match_: Match,
229+
) -> Option<Match> {
230+
let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) =
231+
pattern.expr()
232+
{
233+
let segment = path_exr.path().and_then(|p| p.segment());
234+
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
235+
} else {
236+
(None, None)
237+
};
238+
let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?;
239+
let match_ =
240+
check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?;
241+
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
242+
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
243+
let code_args = once(code.expr()?).chain(code_args);
244+
check_iter(pattern_args, code_args, placeholders, match_)
245+
}
246+
247+
fn check_method_call_and_call(
248+
pattern: MethodCallExpr,
249+
code: CallExpr,
250+
placeholders: &[Var],
251+
match_: Match,
252+
) -> Option<Match> {
253+
let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() {
254+
let segment = path_exr.path().and_then(|p| p.segment());
255+
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
256+
} else {
257+
(None, None)
258+
};
259+
let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?;
260+
let match_ =
261+
check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?;
262+
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
263+
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
264+
let pattern_args = once(pattern.expr()?).chain(pattern_args);
265+
check_iter(pattern_args, code_args, placeholders, match_)
266+
}
267+
222268
fn check_opt_nodes(
223269
pattern: Option<impl AstNode>,
224270
code: Option<impl AstNode>,
@@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
227273
) -> Option<Match> {
228274
match (pattern, code) {
229275
(Some(pattern), Some(code)) => check(
230-
&SyntaxElement::from(pattern.syntax().clone()),
231-
&SyntaxElement::from(code.syntax().clone()),
276+
&pattern.syntax().clone().into(),
277+
&code.syntax().clone().into(),
232278
placeholders,
233279
match_,
234280
),
@@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
237283
}
238284
}
239285

286+
fn check_iter<T, I1, I2>(
287+
mut pattern: I1,
288+
mut code: I2,
289+
placeholders: &[Var],
290+
match_: Match,
291+
) -> Option<Match>
292+
where
293+
T: AstNode,
294+
I1: Iterator<Item = T>,
295+
I2: Iterator<Item = T>,
296+
{
297+
pattern
298+
.by_ref()
299+
.zip(code.by_ref())
300+
.fold(Some(match_), |accum, (a, b)| {
301+
accum.and_then(|match_| {
302+
check(
303+
&a.syntax().clone().into(),
304+
&b.syntax().clone().into(),
305+
placeholders,
306+
match_,
307+
)
308+
})
309+
})
310+
.filter(|_| pattern.next().is_none() && code.next().is_none())
311+
}
312+
240313
fn check(
241314
pattern: &SyntaxElement,
242315
code: &SyntaxElement,
@@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
260333
(RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone()))
261334
{
262335
check_record_lit(pattern, code, placeholders, match_)
336+
} else if let (Some(pattern), Some(code)) =
337+
(CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone()))
338+
{
339+
check_call_and_method_call(pattern, code, placeholders, match_)
340+
} else if let (Some(pattern), Some(code)) =
341+
(MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone()))
342+
{
343+
check_method_call_and_call(pattern, code, placeholders, match_)
263344
} else {
264345
let mut pattern_children = pattern
265346
.children_with_tokens()
@@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
290371
let kind = pattern.pattern.kind();
291372
let matches = code
292373
.descendants()
293-
.filter(|n| n.kind() == kind)
374+
.filter(|n| {
375+
n.kind() == kind
376+
|| (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR)
377+
|| (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR)
378+
})
294379
.filter_map(|code| {
295380
let match_ =
296381
Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] };
297-
check(
298-
&SyntaxElement::from(pattern.pattern.clone()),
299-
&SyntaxElement::from(code),
300-
&pattern.vars,
301-
match_,
302-
)
382+
check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_)
303383
})
304384
.collect();
305385
SsrMatches { matches }
@@ -498,4 +578,22 @@ mod tests {
498578
"fn main() { foo::new(1, 2) }",
499579
)
500580
}
581+
582+
#[test]
583+
fn ssr_call_and_method_call() {
584+
assert_ssr_transform(
585+
"foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)",
586+
"fn main() { get().bar.foo::<'a>(1); }",
587+
"fn main() { foo2(get().bar, 1); }",
588+
)
589+
}
590+
591+
#[test]
592+
fn ssr_method_call_and_call() {
593+
assert_ssr_transform(
594+
"$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)",
595+
"fn main() { X::foo::<i32>(x, 1); }",
596+
"fn main() { x.foo2(1); }",
597+
)
598+
}
501599
}

0 commit comments

Comments
 (0)