@@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
5
5
use ra_ide_db:: symbol_index:: SymbolsDatabase ;
6
6
use ra_ide_db:: RootDatabase ;
7
7
use ra_syntax:: ast:: make:: try_expr_from_text;
8
- use ra_syntax:: ast:: { AstToken , Comment , RecordField , RecordLit } ;
9
- use ra_syntax:: { AstNode , SyntaxElement , SyntaxNode } ;
8
+ use ra_syntax:: ast:: {
9
+ ArgList , AstToken , CallExpr , Comment , Expr , MethodCallExpr , RecordField , RecordLit ,
10
+ } ;
11
+ use ra_syntax:: { AstNode , SyntaxElement , SyntaxKind , SyntaxNode } ;
10
12
use ra_text_edit:: { TextEdit , TextEditBuilder } ;
11
13
use rustc_hash:: FxHashMap ;
12
14
use std:: collections:: HashMap ;
13
- use std:: str:: FromStr ;
15
+ use std:: { iter :: once , str:: FromStr } ;
14
16
15
17
#[ derive( Debug , PartialEq ) ]
16
18
pub struct SsrError ( String ) ;
@@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
219
221
)
220
222
}
221
223
224
+ fn check_call_and_method_call (
225
+ pattern : CallExpr ,
226
+ code : MethodCallExpr ,
227
+ placeholders : & [ Var ] ,
228
+ match_ : Match ,
229
+ ) -> Option < Match > {
230
+ let ( pattern_name, pattern_type_args) = if let Some ( Expr :: PathExpr ( path_exr) ) =
231
+ pattern. expr ( )
232
+ {
233
+ let segment = path_exr. path ( ) . and_then ( |p| p. segment ( ) ) ;
234
+ ( segment. as_ref ( ) . and_then ( |s| s. name_ref ( ) ) , segment. and_then ( |s| s. type_arg_list ( ) ) )
235
+ } else {
236
+ ( None , None )
237
+ } ;
238
+ let match_ = check_opt_nodes ( pattern_name, code. name_ref ( ) , placeholders, match_) ?;
239
+ let match_ =
240
+ check_opt_nodes ( pattern_type_args, code. type_arg_list ( ) , placeholders, match_) ?;
241
+ let pattern_args = pattern. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
242
+ let code_args = code. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
243
+ let code_args = once ( code. expr ( ) ?) . chain ( code_args) ;
244
+ check_iter ( pattern_args, code_args, placeholders, match_)
245
+ }
246
+
247
+ fn check_method_call_and_call (
248
+ pattern : MethodCallExpr ,
249
+ code : CallExpr ,
250
+ placeholders : & [ Var ] ,
251
+ match_ : Match ,
252
+ ) -> Option < Match > {
253
+ let ( code_name, code_type_args) = if let Some ( Expr :: PathExpr ( path_exr) ) = code. expr ( ) {
254
+ let segment = path_exr. path ( ) . and_then ( |p| p. segment ( ) ) ;
255
+ ( segment. as_ref ( ) . and_then ( |s| s. name_ref ( ) ) , segment. and_then ( |s| s. type_arg_list ( ) ) )
256
+ } else {
257
+ ( None , None )
258
+ } ;
259
+ let match_ = check_opt_nodes ( pattern. name_ref ( ) , code_name, placeholders, match_) ?;
260
+ let match_ =
261
+ check_opt_nodes ( pattern. type_arg_list ( ) , code_type_args, placeholders, match_) ?;
262
+ let code_args = code. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
263
+ let pattern_args = pattern. syntax ( ) . children ( ) . find_map ( ArgList :: cast) ?. args ( ) ;
264
+ let pattern_args = once ( pattern. expr ( ) ?) . chain ( pattern_args) ;
265
+ check_iter ( pattern_args, code_args, placeholders, match_)
266
+ }
267
+
222
268
fn check_opt_nodes (
223
269
pattern : Option < impl AstNode > ,
224
270
code : Option < impl AstNode > ,
@@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
227
273
) -> Option < Match > {
228
274
match ( pattern, code) {
229
275
( Some ( pattern) , Some ( code) ) => check (
230
- & SyntaxElement :: from ( pattern. syntax ( ) . clone ( ) ) ,
231
- & SyntaxElement :: from ( code. syntax ( ) . clone ( ) ) ,
276
+ & pattern. syntax ( ) . clone ( ) . into ( ) ,
277
+ & code. syntax ( ) . clone ( ) . into ( ) ,
232
278
placeholders,
233
279
match_,
234
280
) ,
@@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
237
283
}
238
284
}
239
285
286
+ fn check_iter < T , I1 , I2 > (
287
+ mut pattern : I1 ,
288
+ mut code : I2 ,
289
+ placeholders : & [ Var ] ,
290
+ match_ : Match ,
291
+ ) -> Option < Match >
292
+ where
293
+ T : AstNode ,
294
+ I1 : Iterator < Item = T > ,
295
+ I2 : Iterator < Item = T > ,
296
+ {
297
+ pattern
298
+ . by_ref ( )
299
+ . zip ( code. by_ref ( ) )
300
+ . fold ( Some ( match_) , |accum, ( a, b) | {
301
+ accum. and_then ( |match_| {
302
+ check (
303
+ & a. syntax ( ) . clone ( ) . into ( ) ,
304
+ & b. syntax ( ) . clone ( ) . into ( ) ,
305
+ placeholders,
306
+ match_,
307
+ )
308
+ } )
309
+ } )
310
+ . filter ( |_| pattern. next ( ) . is_none ( ) && code. next ( ) . is_none ( ) )
311
+ }
312
+
240
313
fn check (
241
314
pattern : & SyntaxElement ,
242
315
code : & SyntaxElement ,
@@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
260
333
( RecordLit :: cast ( pattern. clone ( ) ) , RecordLit :: cast ( code. clone ( ) ) )
261
334
{
262
335
check_record_lit ( pattern, code, placeholders, match_)
336
+ } else if let ( Some ( pattern) , Some ( code) ) =
337
+ ( CallExpr :: cast ( pattern. clone ( ) ) , MethodCallExpr :: cast ( code. clone ( ) ) )
338
+ {
339
+ check_call_and_method_call ( pattern, code, placeholders, match_)
340
+ } else if let ( Some ( pattern) , Some ( code) ) =
341
+ ( MethodCallExpr :: cast ( pattern. clone ( ) ) , CallExpr :: cast ( code. clone ( ) ) )
342
+ {
343
+ check_method_call_and_call ( pattern, code, placeholders, match_)
263
344
} else {
264
345
let mut pattern_children = pattern
265
346
. children_with_tokens ( )
@@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
290
371
let kind = pattern. pattern . kind ( ) ;
291
372
let matches = code
292
373
. descendants ( )
293
- . filter ( |n| n. kind ( ) == kind)
374
+ . filter ( |n| {
375
+ n. kind ( ) == kind
376
+ || ( kind == SyntaxKind :: CALL_EXPR && n. kind ( ) == SyntaxKind :: METHOD_CALL_EXPR )
377
+ || ( kind == SyntaxKind :: METHOD_CALL_EXPR && n. kind ( ) == SyntaxKind :: CALL_EXPR )
378
+ } )
294
379
. filter_map ( |code| {
295
380
let match_ =
296
381
Match { place : code. clone ( ) , binding : HashMap :: new ( ) , ignored_comments : vec ! [ ] } ;
297
- check (
298
- & SyntaxElement :: from ( pattern. pattern . clone ( ) ) ,
299
- & SyntaxElement :: from ( code) ,
300
- & pattern. vars ,
301
- match_,
302
- )
382
+ check ( & pattern. pattern . clone ( ) . into ( ) , & code. into ( ) , & pattern. vars , match_)
303
383
} )
304
384
. collect ( ) ;
305
385
SsrMatches { matches }
@@ -498,4 +578,22 @@ mod tests {
498
578
"fn main() { foo::new(1, 2) }" ,
499
579
)
500
580
}
581
+
582
+ #[ test]
583
+ fn ssr_call_and_method_call ( ) {
584
+ assert_ssr_transform (
585
+ "foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)" ,
586
+ "fn main() { get().bar.foo::<'a>(1); }" ,
587
+ "fn main() { foo2(get().bar, 1); }" ,
588
+ )
589
+ }
590
+
591
+ #[ test]
592
+ fn ssr_method_call_and_call ( ) {
593
+ assert_ssr_transform (
594
+ "$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)" ,
595
+ "fn main() { X::foo::<i32>(x, 1); }" ,
596
+ "fn main() { x.foo2(1); }" ,
597
+ )
598
+ }
501
599
}
0 commit comments