1
+ use either:: Either ;
1
2
use ide_db:: {
2
3
famous_defs:: FamousDefs ,
3
4
syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
4
5
} ;
5
- use itertools:: Itertools ;
6
6
use syntax:: {
7
- ast:: { self , Expr , HasGenericArgs } ,
8
- match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
7
+ ast:: { self , syntax_factory :: SyntaxFactory , HasArgList , HasGenericArgs } ,
8
+ match_ast, AstNode , NodeOrToken , SyntaxKind ,
9
9
} ;
10
10
11
11
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,11 +39,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
39
39
pub ( crate ) fn unwrap_return_type ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
40
40
let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
41
41
let parent = ret_type. syntax ( ) . parent ( ) ?;
42
- let body = match_ast ! {
42
+ let body_expr = match_ast ! {
43
43
match parent {
44
- ast:: Fn ( func) => func. body( ) ?,
44
+ ast:: Fn ( func) => func. body( ) ?. into ( ) ,
45
45
ast:: ClosureExpr ( closure) => match closure. body( ) ? {
46
- Expr :: BlockExpr ( block) => block,
46
+ ast :: Expr :: BlockExpr ( block) => block. into ( ) ,
47
47
// closures require a block when a return type is specified
48
48
_ => return None ,
49
49
} ,
@@ -65,72 +65,110 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) ->
65
65
let happy_type = extract_wrapped_type ( type_ref) ?;
66
66
67
67
acc. add ( kind. assist_id ( ) , kind. label ( ) , type_ref. syntax ( ) . text_range ( ) , |builder| {
68
- let body = ast:: Expr :: BlockExpr ( body) ;
68
+ let mut editor = builder. make_editor ( & parent) ;
69
+ let make = SyntaxFactory :: new ( ) ;
69
70
70
71
let mut exprs_to_unwrap = Vec :: new ( ) ;
71
72
let tail_cb = & mut |e : & _ | tail_cb_impl ( & mut exprs_to_unwrap, e) ;
72
- walk_expr ( & body , & mut |expr| {
73
- if let Expr :: ReturnExpr ( ret_expr) = expr {
73
+ walk_expr ( & body_expr , & mut |expr| {
74
+ if let ast :: Expr :: ReturnExpr ( ret_expr) = expr {
74
75
if let Some ( ret_expr_arg) = & ret_expr. expr ( ) {
75
76
for_each_tail_expr ( ret_expr_arg, tail_cb) ;
76
77
}
77
78
}
78
79
} ) ;
79
- for_each_tail_expr ( & body , tail_cb) ;
80
+ for_each_tail_expr ( & body_expr , tail_cb) ;
80
81
81
82
let is_unit_type = is_unit_type ( & happy_type) ;
82
83
if is_unit_type {
83
- let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
84
-
85
84
if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
86
85
if token. kind ( ) == SyntaxKind :: WHITESPACE {
87
- text_range = TextRange :: new ( text_range . start ( ) , token. text_range ( ) . end ( ) ) ;
86
+ editor . delete ( token) ;
88
87
}
89
88
}
90
89
91
- builder . delete ( text_range ) ;
90
+ editor . delete ( ret_type . syntax ( ) ) ;
92
91
} else {
93
- builder . replace ( type_ref. syntax ( ) . text_range ( ) , happy_type. syntax ( ) . text ( ) ) ;
92
+ editor . replace ( type_ref. syntax ( ) , happy_type. syntax ( ) ) ;
94
93
}
95
94
96
- for ret_expr_arg in exprs_to_unwrap {
97
- let ret_expr_str = ret_expr_arg. to_string ( ) ;
98
-
99
- let needs_replacing = match kind {
100
- UnwrapperKind :: Option => ret_expr_str. starts_with ( "Some(" ) ,
101
- UnwrapperKind :: Result => {
102
- ret_expr_str. starts_with ( "Ok(" ) || ret_expr_str. starts_with ( "Err(" )
103
- }
104
- } ;
95
+ let mut final_placeholder = None ;
96
+ for tail_expr in exprs_to_unwrap {
97
+ match & tail_expr {
98
+ ast:: Expr :: CallExpr ( call_expr) => {
99
+ let ast:: Expr :: PathExpr ( path_expr) = call_expr. expr ( ) . unwrap ( ) else {
100
+ continue ;
101
+ } ;
102
+
103
+ let path_str = path_expr. path ( ) . unwrap ( ) . to_string ( ) ;
104
+ let needs_replacing = match kind {
105
+ UnwrapperKind :: Option => path_str == "Some" ,
106
+ UnwrapperKind :: Result => path_str == "Ok" || path_str == "Err" ,
107
+ } ;
108
+
109
+ if !needs_replacing {
110
+ continue ;
111
+ }
105
112
106
- if needs_replacing {
107
- let arg_list = ret_expr_arg. syntax ( ) . children ( ) . find_map ( ast:: ArgList :: cast) ;
108
- if let Some ( arg_list) = arg_list {
113
+ let arg_list = call_expr. arg_list ( ) . unwrap ( ) ;
109
114
if is_unit_type {
110
- match ret_expr_arg . syntax ( ) . prev_sibling_or_token ( ) {
111
- // Useful to delete the entire line without leaving trailing whitespaces
112
- Some ( whitespace ) => {
113
- let new_range = TextRange :: new (
114
- whitespace . text_range ( ) . start ( ) ,
115
- ret_expr_arg . syntax ( ) . text_range ( ) . end ( ) ,
116
- ) ;
117
- builder . delete ( new_range ) ;
115
+ let tail_parent = tail_expr
116
+ . syntax ( )
117
+ . parent ( )
118
+ . and_then ( Either :: < ast :: ReturnExpr , ast :: StmtList > :: cast )
119
+ . unwrap ( ) ;
120
+ match tail_parent {
121
+ Either :: Left ( ret_expr ) => {
122
+ editor . replace ( ret_expr . syntax ( ) , make . expr_return ( None ) . syntax ( ) )
118
123
}
119
- None => {
120
- builder. delete ( ret_expr_arg. syntax ( ) . text_range ( ) ) ;
124
+ Either :: Right ( stmt_list) => {
125
+ let new_block = if stmt_list. statements ( ) . next ( ) . is_none ( ) {
126
+ make. expr_empty_block ( )
127
+ } else {
128
+ make. block_expr ( stmt_list. statements ( ) , None )
129
+ } ;
130
+ editor. replace (
131
+ stmt_list. syntax ( ) ,
132
+ new_block. stmt_list ( ) . unwrap ( ) . syntax ( ) ,
133
+ ) ;
121
134
}
122
135
}
123
- } else {
124
- builder. replace (
125
- ret_expr_arg. syntax ( ) . text_range ( ) ,
126
- arg_list. args ( ) . join ( ", " ) ,
136
+ } else if let Some ( first_arg) = arg_list. args ( ) . next ( ) {
137
+ editor. replace ( tail_expr. syntax ( ) , first_arg. syntax ( ) ) ;
138
+ }
139
+ }
140
+ ast:: Expr :: PathExpr ( path_expr) => {
141
+ let UnwrapperKind :: Option = kind else {
142
+ continue ;
143
+ } ;
144
+
145
+ if path_expr. path ( ) . unwrap ( ) . to_string ( ) != "None" {
146
+ continue ;
147
+ }
148
+
149
+ let new_tail_expr = make. expr_unit ( ) ;
150
+ editor. replace ( path_expr. syntax ( ) , new_tail_expr. syntax ( ) ) ;
151
+ if let Some ( cap) = ctx. config . snippet_cap {
152
+ editor. add_annotation (
153
+ new_tail_expr. syntax ( ) ,
154
+ builder. make_placeholder_snippet ( cap) ,
127
155
) ;
156
+
157
+ final_placeholder = Some ( new_tail_expr) ;
128
158
}
129
159
}
130
- } else if matches ! ( kind, UnwrapperKind :: Option if ret_expr_str == "None" ) {
131
- builder. replace ( ret_expr_arg. syntax ( ) . text_range ( ) , "()" ) ;
160
+ _ => ( ) ,
132
161
}
133
162
}
163
+
164
+ if let Some ( cap) = ctx. config . snippet_cap {
165
+ if let Some ( final_placeholder) = final_placeholder {
166
+ editor. add_annotation ( final_placeholder. syntax ( ) , builder. make_tabstop_after ( cap) ) ;
167
+ }
168
+ }
169
+
170
+ editor. add_mappings ( make. finish_with_mappings ( ) ) ;
171
+ builder. add_file_edits ( ctx. file_id ( ) , editor) ;
134
172
} )
135
173
}
136
174
@@ -168,12 +206,12 @@ impl UnwrapperKind {
168
206
169
207
fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
170
208
match e {
171
- Expr :: BreakExpr ( break_expr) => {
209
+ ast :: Expr :: BreakExpr ( break_expr) => {
172
210
if let Some ( break_expr_arg) = break_expr. expr ( ) {
173
211
for_each_tail_expr ( & break_expr_arg, & mut |e| tail_cb_impl ( acc, e) )
174
212
}
175
213
}
176
- Expr :: ReturnExpr ( _) => {
214
+ ast :: Expr :: ReturnExpr ( _) => {
177
215
// all return expressions have already been handled by the walk loop
178
216
}
179
217
e => acc. push ( e. clone ( ) ) ,
@@ -238,8 +276,7 @@ fn foo() -> Option<()$0> {
238
276
}
239
277
"# ,
240
278
r#"
241
- fn foo() {
242
- }
279
+ fn foo() {}
243
280
"# ,
244
281
"Unwrap Option return type" ,
245
282
) ;
@@ -254,8 +291,7 @@ fn foo() -> Option<()$0>{
254
291
}
255
292
"# ,
256
293
r#"
257
- fn foo() {
258
- }
294
+ fn foo() {}
259
295
"# ,
260
296
"Unwrap Option return type" ,
261
297
) ;
@@ -280,7 +316,42 @@ fn foo() -> i32 {
280
316
if true {
281
317
42
282
318
} else {
283
- ()
319
+ ${1:()}$0
320
+ }
321
+ }
322
+ "# ,
323
+ "Unwrap Option return type" ,
324
+ ) ;
325
+ }
326
+
327
+ #[ test]
328
+ fn unwrap_option_return_type_multi_none ( ) {
329
+ check_assist_by_label (
330
+ unwrap_return_type,
331
+ r#"
332
+ //- minicore: option
333
+ fn foo() -> Option<i3$02> {
334
+ if false {
335
+ return None;
336
+ }
337
+
338
+ if true {
339
+ Some(42)
340
+ } else {
341
+ None
342
+ }
343
+ }
344
+ "# ,
345
+ r#"
346
+ fn foo() -> i32 {
347
+ if false {
348
+ return ${1:()};
349
+ }
350
+
351
+ if true {
352
+ 42
353
+ } else {
354
+ ${2:()}$0
284
355
}
285
356
}
286
357
"# ,
@@ -1262,8 +1333,7 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
1262
1333
}
1263
1334
"# ,
1264
1335
r#"
1265
- fn foo() {
1266
- }
1336
+ fn foo() {}
1267
1337
"# ,
1268
1338
"Unwrap Result return type" ,
1269
1339
) ;
@@ -1278,8 +1348,7 @@ fn foo() -> Result<(), Box<dyn Error$0>>{
1278
1348
}
1279
1349
"# ,
1280
1350
r#"
1281
- fn foo() {
1282
- }
1351
+ fn foo() {}
1283
1352
"# ,
1284
1353
"Unwrap Result return type" ,
1285
1354
) ;
0 commit comments