1
+ use either:: Either ;
1
2
use ide_db:: {
2
3
famous_defs:: FamousDefs ,
3
4
syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
4
5
} ;
5
- use itertools:: Itertools ;
6
6
use syntax:: {
7
- ast:: { self , Expr , HasGenericArgs } ,
8
- match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
7
+ ast:: { self , syntax_factory :: SyntaxFactory , HasArgList , HasGenericArgs } ,
8
+ match_ast, AstNode , NodeOrToken , SyntaxKind ,
9
9
} ;
10
10
11
11
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,11 +39,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
39
39
pub ( crate ) fn unwrap_return_type ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
40
40
let ret_type = ctx. find_node_at_offset :: < ast:: RetType > ( ) ?;
41
41
let parent = ret_type. syntax ( ) . parent ( ) ?;
42
- let body = match_ast ! {
42
+ let body_expr = match_ast ! {
43
43
match parent {
44
- ast:: Fn ( func) => func. body( ) ?,
44
+ ast:: Fn ( func) => func. body( ) ?. into ( ) ,
45
45
ast:: ClosureExpr ( closure) => match closure. body( ) ? {
46
- Expr :: BlockExpr ( block) => block,
46
+ ast :: Expr :: BlockExpr ( block) => block. into ( ) ,
47
47
// closures require a block when a return type is specified
48
48
_ => return None ,
49
49
} ,
@@ -65,72 +65,94 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) ->
65
65
let happy_type = extract_wrapped_type ( type_ref) ?;
66
66
67
67
acc. add ( kind. assist_id ( ) , kind. label ( ) , type_ref. syntax ( ) . text_range ( ) , |builder| {
68
- let body = ast:: Expr :: BlockExpr ( body) ;
68
+ let mut editor = builder. make_editor ( & parent) ;
69
+ let make = SyntaxFactory :: new ( ) ;
69
70
70
71
let mut exprs_to_unwrap = Vec :: new ( ) ;
71
72
let tail_cb = & mut |e : & _ | tail_cb_impl ( & mut exprs_to_unwrap, e) ;
72
- walk_expr ( & body , & mut |expr| {
73
- if let Expr :: ReturnExpr ( ret_expr) = expr {
73
+ walk_expr ( & body_expr , & mut |expr| {
74
+ if let ast :: Expr :: ReturnExpr ( ret_expr) = expr {
74
75
if let Some ( ret_expr_arg) = & ret_expr. expr ( ) {
75
76
for_each_tail_expr ( ret_expr_arg, tail_cb) ;
76
77
}
77
78
}
78
79
} ) ;
79
- for_each_tail_expr ( & body , tail_cb) ;
80
+ for_each_tail_expr ( & body_expr , tail_cb) ;
80
81
81
82
let is_unit_type = is_unit_type ( & happy_type) ;
82
83
if is_unit_type {
83
- let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
84
-
85
84
if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
86
85
if token. kind ( ) == SyntaxKind :: WHITESPACE {
87
- text_range = TextRange :: new ( text_range . start ( ) , token. text_range ( ) . end ( ) ) ;
86
+ editor . delete ( token) ;
88
87
}
89
88
}
90
89
91
- builder . delete ( text_range ) ;
90
+ editor . delete ( ret_type . syntax ( ) ) ;
92
91
} else {
93
- builder . replace ( type_ref. syntax ( ) . text_range ( ) , happy_type. syntax ( ) . text ( ) ) ;
92
+ editor . replace ( type_ref. syntax ( ) , happy_type. syntax ( ) ) ;
94
93
}
95
94
96
- for ret_expr_arg in exprs_to_unwrap {
97
- let ret_expr_str = ret_expr_arg. to_string ( ) ;
98
-
99
- let needs_replacing = match kind {
100
- UnwrapperKind :: Option => ret_expr_str. starts_with ( "Some(" ) ,
101
- UnwrapperKind :: Result => {
102
- ret_expr_str. starts_with ( "Ok(" ) || ret_expr_str. starts_with ( "Err(" )
103
- }
104
- } ;
95
+ for tail_expr in exprs_to_unwrap {
96
+ match & tail_expr {
97
+ ast:: Expr :: CallExpr ( call_expr) => {
98
+ let ast:: Expr :: PathExpr ( path_expr) = call_expr. expr ( ) . unwrap ( ) else {
99
+ continue ;
100
+ } ;
101
+
102
+ let path_str = path_expr. path ( ) . unwrap ( ) . to_string ( ) ;
103
+ let needs_replacing = match kind {
104
+ UnwrapperKind :: Option => path_str == "Some" ,
105
+ UnwrapperKind :: Result => path_str == "Ok" || path_str == "Err" ,
106
+ } ;
107
+
108
+ if !needs_replacing {
109
+ continue ;
110
+ }
105
111
106
- if needs_replacing {
107
- let arg_list = ret_expr_arg. syntax ( ) . children ( ) . find_map ( ast:: ArgList :: cast) ;
108
- if let Some ( arg_list) = arg_list {
112
+ let arg_list = call_expr. arg_list ( ) . unwrap ( ) ;
109
113
if is_unit_type {
110
- match ret_expr_arg . syntax ( ) . prev_sibling_or_token ( ) {
111
- // Useful to delete the entire line without leaving trailing whitespaces
112
- Some ( whitespace ) => {
113
- let new_range = TextRange :: new (
114
- whitespace . text_range ( ) . start ( ) ,
115
- ret_expr_arg . syntax ( ) . text_range ( ) . end ( ) ,
116
- ) ;
117
- builder . delete ( new_range ) ;
114
+ let tail_parent = tail_expr
115
+ . syntax ( )
116
+ . parent ( )
117
+ . and_then ( Either :: < ast :: ReturnExpr , ast :: StmtList > :: cast )
118
+ . unwrap ( ) ;
119
+ match tail_parent {
120
+ Either :: Left ( ret_expr ) => {
121
+ editor . replace ( ret_expr . syntax ( ) , make . expr_return ( None ) . syntax ( ) )
118
122
}
119
- None => {
120
- builder. delete ( ret_expr_arg. syntax ( ) . text_range ( ) ) ;
123
+ Either :: Right ( stmt_list) => {
124
+ let new_block = if stmt_list. statements ( ) . next ( ) . is_none ( ) {
125
+ make. expr_empty_block ( )
126
+ } else {
127
+ make. block_expr ( stmt_list. statements ( ) , None )
128
+ } ;
129
+ editor. replace (
130
+ stmt_list. syntax ( ) ,
131
+ new_block. stmt_list ( ) . unwrap ( ) . syntax ( ) ,
132
+ ) ;
121
133
}
122
134
}
123
- } else {
124
- builder. replace (
125
- ret_expr_arg. syntax ( ) . text_range ( ) ,
126
- arg_list. args ( ) . join ( ", " ) ,
127
- ) ;
135
+ } else if let Some ( first_arg) = arg_list. args ( ) . next ( ) {
136
+ editor. replace ( tail_expr. syntax ( ) , first_arg. syntax ( ) ) ;
128
137
}
129
138
}
130
- } else if matches ! ( kind, UnwrapperKind :: Option if ret_expr_str == "None" ) {
131
- builder. replace ( ret_expr_arg. syntax ( ) . text_range ( ) , "()" ) ;
139
+ ast:: Expr :: PathExpr ( path_expr) => {
140
+ let UnwrapperKind :: Option = kind else {
141
+ continue ;
142
+ } ;
143
+
144
+ if path_expr. path ( ) . unwrap ( ) . to_string ( ) != "None" {
145
+ continue ;
146
+ }
147
+
148
+ editor. replace ( path_expr. syntax ( ) , make. expr_unit ( ) . syntax ( ) ) ;
149
+ }
150
+ _ => ( ) ,
132
151
}
133
152
}
153
+
154
+ editor. add_mappings ( make. finish_with_mappings ( ) ) ;
155
+ builder. add_file_edits ( ctx. file_id ( ) , editor) ;
134
156
} )
135
157
}
136
158
@@ -168,12 +190,12 @@ impl UnwrapperKind {
168
190
169
191
fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
170
192
match e {
171
- Expr :: BreakExpr ( break_expr) => {
193
+ ast :: Expr :: BreakExpr ( break_expr) => {
172
194
if let Some ( break_expr_arg) = break_expr. expr ( ) {
173
195
for_each_tail_expr ( & break_expr_arg, & mut |e| tail_cb_impl ( acc, e) )
174
196
}
175
197
}
176
- Expr :: ReturnExpr ( _) => {
198
+ ast :: Expr :: ReturnExpr ( _) => {
177
199
// all return expressions have already been handled by the walk loop
178
200
}
179
201
e => acc. push ( e. clone ( ) ) ,
@@ -238,8 +260,7 @@ fn foo() -> Option<()$0> {
238
260
}
239
261
"# ,
240
262
r#"
241
- fn foo() {
242
- }
263
+ fn foo() {}
243
264
"# ,
244
265
"Unwrap Option return type" ,
245
266
) ;
@@ -254,8 +275,7 @@ fn foo() -> Option<()$0>{
254
275
}
255
276
"# ,
256
277
r#"
257
- fn foo() {
258
- }
278
+ fn foo() {}
259
279
"# ,
260
280
"Unwrap Option return type" ,
261
281
) ;
@@ -1262,8 +1282,7 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
1262
1282
}
1263
1283
"# ,
1264
1284
r#"
1265
- fn foo() {
1266
- }
1285
+ fn foo() {}
1267
1286
"# ,
1268
1287
"Unwrap Result return type" ,
1269
1288
) ;
@@ -1278,8 +1297,7 @@ fn foo() -> Result<(), Box<dyn Error$0>>{
1278
1297
}
1279
1298
"# ,
1280
1299
r#"
1281
- fn foo() {
1282
- }
1300
+ fn foo() {}
1283
1301
"# ,
1284
1302
"Unwrap Result return type" ,
1285
1303
) ;
0 commit comments