1
+ use either:: Either ;
1
2
use ide_db:: {
2
3
famous_defs:: FamousDefs ,
3
4
syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
4
5
} ;
5
- use itertools:: Itertools ;
6
6
use syntax:: {
7
- ast:: { self , Expr , HasGenericArgs } ,
8
- match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
7
+ ast:: { self , syntax_factory :: SyntaxFactory , HasArgList , HasGenericArgs } ,
8
+ match_ast, AstNode , NodeOrToken , SyntaxKind ,
9
9
} ;
10
10
11
11
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,11 +39,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
39
39
pub ( crate ) fn unwrap_return_type ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
40
40
let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
41
41
let parent = ret_type. syntax ( ) . parent ( ) ?;
42
- let body = match_ast ! {
42
+ let body_expr = match_ast ! {
43
43
match parent {
44
- ast:: Fn ( func) => func. body( ) ?,
44
+ ast:: Fn ( func) => func. body( ) ?. into ( ) ,
45
45
ast:: ClosureExpr ( closure) => match closure. body( ) ? {
46
- Expr :: BlockExpr ( block) => block,
46
+ ast :: Expr :: BlockExpr ( block) => block. into ( ) ,
47
47
// closures require a block when a return type is specified
48
48
_ => return None ,
49
49
} ,
@@ -65,72 +65,96 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) ->
65
65
let happy_type = extract_wrapped_type ( type_ref) ?;
66
66
67
67
acc. add ( kind. assist_id ( ) , kind. label ( ) , type_ref. syntax ( ) . text_range ( ) , |builder| {
68
- let body = ast:: Expr :: BlockExpr ( body) ;
68
+ let mut editor = builder. make_editor ( & parent) ;
69
+ let make = SyntaxFactory :: new ( ) ;
69
70
70
71
let mut exprs_to_unwrap = Vec :: new ( ) ;
71
72
let tail_cb = & mut |e : & _ | tail_cb_impl ( & mut exprs_to_unwrap, e) ;
72
- walk_expr ( & body , & mut |expr| {
73
- if let Expr :: ReturnExpr ( ret_expr) = expr {
73
+ walk_expr ( & body_expr , & mut |expr| {
74
+ if let ast :: Expr :: ReturnExpr ( ret_expr) = expr {
74
75
if let Some ( ret_expr_arg) = & ret_expr. expr ( ) {
75
76
for_each_tail_expr ( ret_expr_arg, tail_cb) ;
76
77
}
77
78
}
78
79
} ) ;
79
- for_each_tail_expr ( & body , tail_cb) ;
80
+ for_each_tail_expr ( & body_expr , tail_cb) ;
80
81
81
82
let is_unit_type = is_unit_type ( & happy_type) ;
82
83
if is_unit_type {
83
- let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
84
-
85
84
if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
86
85
if token. kind ( ) == SyntaxKind :: WHITESPACE {
87
- text_range = TextRange :: new ( text_range . start ( ) , token. text_range ( ) . end ( ) ) ;
86
+ editor . delete ( token) ;
88
87
}
89
88
}
90
89
91
- builder . delete ( text_range ) ;
90
+ editor . delete ( ret_type . syntax ( ) ) ;
92
91
} else {
93
- builder. replace ( type_ref. syntax ( ) . text_range ( ) , happy_type. syntax ( ) . text ( ) ) ;
92
+ // FIXME: remove `clone_for_update` when `SyntaxEditor` handles it for us
93
+ editor. replace ( type_ref. syntax ( ) , happy_type. syntax ( ) . clone_for_update ( ) ) ;
94
94
}
95
95
96
- for ret_expr_arg in exprs_to_unwrap {
97
- let ret_expr_str = ret_expr_arg. to_string ( ) ;
98
-
99
- let needs_replacing = match kind {
100
- UnwrapperKind :: Option => ret_expr_str. starts_with ( "Some(" ) ,
101
- UnwrapperKind :: Result => {
102
- ret_expr_str. starts_with ( "Ok(" ) || ret_expr_str. starts_with ( "Err(" )
103
- }
104
- } ;
96
+ for tail_expr in exprs_to_unwrap {
97
+ match & tail_expr {
98
+ ast:: Expr :: CallExpr ( call_expr) => {
99
+ let ast:: Expr :: PathExpr ( path_expr) = call_expr. expr ( ) . unwrap ( ) else {
100
+ continue ;
101
+ } ;
102
+
103
+ let path_str = path_expr. path ( ) . unwrap ( ) . to_string ( ) ;
104
+ let needs_replacing = match kind {
105
+ UnwrapperKind :: Option => path_str == "Some" ,
106
+ UnwrapperKind :: Result => path_str == "Ok" || path_str == "Err" ,
107
+ } ;
108
+
109
+ if !needs_replacing {
110
+ continue ;
111
+ }
105
112
106
- if needs_replacing {
107
- let arg_list = ret_expr_arg. syntax ( ) . children ( ) . find_map ( ast:: ArgList :: cast) ;
108
- if let Some ( arg_list) = arg_list {
113
+ let arg_list = call_expr. arg_list ( ) . unwrap ( ) ;
109
114
if is_unit_type {
110
- match ret_expr_arg . syntax ( ) . prev_sibling_or_token ( ) {
111
- // Useful to delete the entire line without leaving trailing whitespaces
112
- Some ( whitespace ) => {
113
- let new_range = TextRange :: new (
114
- whitespace . text_range ( ) . start ( ) ,
115
- ret_expr_arg . syntax ( ) . text_range ( ) . end ( ) ,
116
- ) ;
117
- builder . delete ( new_range ) ;
115
+ let tail_parent = tail_expr
116
+ . syntax ( )
117
+ . parent ( )
118
+ . and_then ( Either :: < ast :: ReturnExpr , ast :: StmtList > :: cast )
119
+ . unwrap ( ) ;
120
+ match tail_parent {
121
+ Either :: Left ( ret_expr ) => {
122
+ editor . replace ( ret_expr . syntax ( ) , make . expr_return ( None ) . syntax ( ) )
118
123
}
119
- None => {
120
- builder. delete ( ret_expr_arg. syntax ( ) . text_range ( ) ) ;
124
+ Either :: Right ( stmt_list) => {
125
+ let new_block = if stmt_list. statements ( ) . next ( ) . is_none ( ) {
126
+ make. expr_empty_block ( )
127
+ } else {
128
+ make. block_expr ( stmt_list. statements ( ) , None )
129
+ } ;
130
+ editor. replace (
131
+ stmt_list. syntax ( ) ,
132
+ new_block. stmt_list ( ) . unwrap ( ) . syntax ( ) ,
133
+ ) ;
121
134
}
122
135
}
123
- } else {
124
- builder. replace (
125
- ret_expr_arg. syntax ( ) . text_range ( ) ,
126
- arg_list. args ( ) . join ( ", " ) ,
127
- ) ;
136
+ } else if let Some ( first_arg) = arg_list. args ( ) . next ( ) {
137
+ // FIXME: remove `clone_for_update` when `SyntaxEditor` handles it for us
138
+ editor. replace ( tail_expr. syntax ( ) , first_arg. syntax ( ) . clone_for_update ( ) ) ;
128
139
}
129
140
}
130
- } else if matches ! ( kind, UnwrapperKind :: Option if ret_expr_str == "None" ) {
131
- builder. replace ( ret_expr_arg. syntax ( ) . text_range ( ) , "()" ) ;
141
+ ast:: Expr :: PathExpr ( path_expr) => {
142
+ let UnwrapperKind :: Option = kind else {
143
+ continue ;
144
+ } ;
145
+
146
+ if path_expr. path ( ) . unwrap ( ) . to_string ( ) != "None" {
147
+ continue ;
148
+ }
149
+
150
+ editor. replace ( path_expr. syntax ( ) , make. expr_unit ( ) . syntax ( ) ) ;
151
+ }
152
+ _ => ( ) ,
132
153
}
133
154
}
155
+
156
+ editor. add_mappings ( make. finish_with_mappings ( ) ) ;
157
+ builder. add_file_edits ( ctx. file_id ( ) , editor) ;
134
158
} )
135
159
}
136
160
@@ -168,12 +192,12 @@ impl UnwrapperKind {
168
192
169
193
fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
170
194
match e {
171
- Expr :: BreakExpr ( break_expr) => {
195
+ ast :: Expr :: BreakExpr ( break_expr) => {
172
196
if let Some ( break_expr_arg) = break_expr. expr ( ) {
173
197
for_each_tail_expr ( & break_expr_arg, & mut |e| tail_cb_impl ( acc, e) )
174
198
}
175
199
}
176
- Expr :: ReturnExpr ( _) => {
200
+ ast :: Expr :: ReturnExpr ( _) => {
177
201
// all return expressions have already been handled by the walk loop
178
202
}
179
203
e => acc. push ( e. clone ( ) ) ,
@@ -238,8 +262,7 @@ fn foo() -> Option<()$0> {
238
262
}
239
263
"# ,
240
264
r#"
241
- fn foo() {
242
- }
265
+ fn foo() {}
243
266
"# ,
244
267
"Unwrap Option return type" ,
245
268
) ;
@@ -254,8 +277,7 @@ fn foo() -> Option<()$0>{
254
277
}
255
278
"# ,
256
279
r#"
257
- fn foo() {
258
- }
280
+ fn foo() {}
259
281
"# ,
260
282
"Unwrap Option return type" ,
261
283
) ;
@@ -1262,8 +1284,7 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
1262
1284
}
1263
1285
"# ,
1264
1286
r#"
1265
- fn foo() {
1266
- }
1287
+ fn foo() {}
1267
1288
"# ,
1268
1289
"Unwrap Result return type" ,
1269
1290
) ;
@@ -1278,8 +1299,7 @@ fn foo() -> Result<(), Box<dyn Error$0>>{
1278
1299
}
1279
1300
"# ,
1280
1301
r#"
1281
- fn foo() {
1282
- }
1302
+ fn foo() {}
1283
1303
"# ,
1284
1304
"Unwrap Result return type" ,
1285
1305
) ;
0 commit comments