@@ -5,7 +5,7 @@ use ide_db::{
5
5
use itertools:: Itertools ;
6
6
use syntax:: {
7
7
ast:: { self , Expr } ,
8
- match_ast, AstNode , TextRange , TextSize ,
8
+ match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
9
9
} ;
10
10
11
11
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
38
38
} ;
39
39
40
40
let type_ref = & ret_type. ty ( ) ?;
41
- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
41
+ let Some ( hir :: Adt :: Enum ( ret_enum ) ) = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) else { return None ; } ;
42
42
let result_enum =
43
43
FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
44
-
45
- if !matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
44
+ if ret_enum != result_enum {
46
45
return None ;
47
46
}
48
47
48
+ let Some ( ok_type) = unwrap_result_type ( type_ref) else { return None ; } ;
49
+
49
50
acc. add (
50
51
AssistId ( "unwrap_result_return_type" , AssistKind :: RefactorRewrite ) ,
51
52
"Unwrap Result return type" ,
@@ -64,26 +65,19 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
64
65
} ) ;
65
66
for_each_tail_expr ( & body, tail_cb) ;
66
67
67
- let mut is_unit_type = false ;
68
- if let Some ( ( _, inner_type) ) = type_ref. to_string ( ) . split_once ( '<' ) {
69
- let inner_type = match inner_type. split_once ( ',' ) {
70
- Some ( ( success_inner_type, _) ) => success_inner_type,
71
- None => inner_type,
72
- } ;
73
- let new_ret_type = inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ;
74
- if new_ret_type == "()" {
75
- is_unit_type = true ;
76
- let text_range = TextRange :: new (
77
- ret_type. syntax ( ) . text_range ( ) . start ( ) ,
78
- ret_type. syntax ( ) . text_range ( ) . end ( ) + TextSize :: from ( 1u32 ) ,
79
- ) ;
80
- builder. delete ( text_range)
81
- } else {
82
- builder. replace (
83
- type_ref. syntax ( ) . text_range ( ) ,
84
- inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ,
85
- )
68
+ let is_unit_type = is_unit_type ( & ok_type) ;
69
+ if is_unit_type {
70
+ let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
71
+
72
+ if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
73
+ if token. kind ( ) == SyntaxKind :: WHITESPACE {
74
+ text_range = TextRange :: new ( text_range. start ( ) , token. text_range ( ) . end ( ) ) ;
75
+ }
86
76
}
77
+
78
+ builder. delete ( text_range) ;
79
+ } else {
80
+ builder. replace ( type_ref. syntax ( ) . text_range ( ) , ok_type. syntax ( ) . text ( ) ) ;
87
81
}
88
82
89
83
for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +128,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134
128
}
135
129
}
136
130
131
+ // Tries to extract `T` from `Result<T, E>`.
132
+ fn unwrap_result_type ( ty : & ast:: Type ) -> Option < ast:: Type > {
133
+ let ast:: Type :: PathType ( path_ty) = ty else { return None ; } ;
134
+ let path = path_ty. path ( ) ?;
135
+ let segment = path. first_segment ( ) ?;
136
+ let generic_arg_list = segment. generic_arg_list ( ) ?;
137
+ let generic_args: Vec < _ > = generic_arg_list. generic_args ( ) . collect ( ) ;
138
+ let ast:: GenericArg :: TypeArg ( ok_type) = generic_args. first ( ) ? else { return None ; } ;
139
+ ok_type. ty ( )
140
+ }
141
+
142
+ fn is_unit_type ( ty : & ast:: Type ) -> bool {
143
+ let ast:: Type :: TupleType ( tuple) = ty else { return false } ;
144
+ tuple. fields ( ) . next ( ) . is_none ( )
145
+ }
146
+
137
147
#[ cfg( test) ]
138
148
mod tests {
139
149
use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -173,6 +183,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173
183
r#"
174
184
fn foo() {
175
185
}
186
+ "# ,
187
+ ) ;
188
+
189
+ // Unformatted return type
190
+ check_assist (
191
+ unwrap_result_return_type,
192
+ r#"
193
+ //- minicore: result
194
+ fn foo() -> Result<(), Box<dyn Error$0>>{
195
+ Ok(())
196
+ }
197
+ "# ,
198
+ r#"
199
+ fn foo() {
200
+ }
176
201
"# ,
177
202
) ;
178
203
}
@@ -1014,6 +1039,54 @@ fn foo(the_field: u32) -> u32 {
1014
1039
}
1015
1040
the_field
1016
1041
}
1042
+ "# ,
1043
+ ) ;
1044
+ }
1045
+
1046
+ #[ test]
1047
+ fn unwrap_result_return_type_nested_type ( ) {
1048
+ check_assist (
1049
+ unwrap_result_return_type,
1050
+ r#"
1051
+ //- minicore: result, option
1052
+ fn foo() -> Result<Option<i32$0>, ()> {
1053
+ Ok(Some(42))
1054
+ }
1055
+ "# ,
1056
+ r#"
1057
+ fn foo() -> Option<i32> {
1058
+ Some(42)
1059
+ }
1060
+ "# ,
1061
+ ) ;
1062
+
1063
+ check_assist (
1064
+ unwrap_result_return_type,
1065
+ r#"
1066
+ //- minicore: result, option
1067
+ fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1068
+ Ok(None)
1069
+ }
1070
+ "# ,
1071
+ r#"
1072
+ fn foo() -> Option<Result<i32, ()>> {
1073
+ None
1074
+ }
1075
+ "# ,
1076
+ ) ;
1077
+
1078
+ check_assist (
1079
+ unwrap_result_return_type,
1080
+ r#"
1081
+ //- minicore: result, option, iterators
1082
+ fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1083
+ Ok(Some(42).into_iter())
1084
+ }
1085
+ "# ,
1086
+ r#"
1087
+ fn foo() -> impl Iterator<Item = i32> {
1088
+ Some(42).into_iter()
1089
+ }
1017
1090
"# ,
1018
1091
) ;
1019
1092
}
0 commit comments