Skip to content

Commit 237ffa3

Browse files
committed
Auto merge of rust-lang#14667 - unexge:nested-types-in-unwrap-result-type, r=HKalbasi
Handle nested types in `unwrap_result_return_type` assist Fixes rust-lang/rust-analyzer#14496
2 parents 797c2f1 + a2ab7ee commit 237ffa3

File tree

1 file changed

+96
-23
lines changed

1 file changed

+96
-23
lines changed

crates/ide-assists/src/handlers/unwrap_result_return_type.rs

+96-23
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ide_db::{
55
use itertools::Itertools;
66
use syntax::{
77
ast::{self, Expr},
8-
match_ast, AstNode, TextRange, TextSize,
8+
match_ast, AstNode, NodeOrToken, SyntaxKind, TextRange,
99
};
1010

1111
use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
3838
};
3939

4040
let type_ref = &ret_type.ty()?;
41-
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
41+
let Some(hir::Adt::Enum(ret_enum)) = ctx.sema.resolve_type(type_ref)?.as_adt() else { return None; };
4242
let result_enum =
4343
FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
44-
45-
if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
44+
if ret_enum != result_enum {
4645
return None;
4746
}
4847

48+
let Some(ok_type) = unwrap_result_type(type_ref) else { return None; };
49+
4950
acc.add(
5051
AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
5152
"Unwrap Result return type",
@@ -64,26 +65,19 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
6465
});
6566
for_each_tail_expr(&body, tail_cb);
6667

67-
let mut is_unit_type = false;
68-
if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
69-
let inner_type = match inner_type.split_once(',') {
70-
Some((success_inner_type, _)) => success_inner_type,
71-
None => inner_type,
72-
};
73-
let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
74-
if new_ret_type == "()" {
75-
is_unit_type = true;
76-
let text_range = TextRange::new(
77-
ret_type.syntax().text_range().start(),
78-
ret_type.syntax().text_range().end() + TextSize::from(1u32),
79-
);
80-
builder.delete(text_range)
81-
} else {
82-
builder.replace(
83-
type_ref.syntax().text_range(),
84-
inner_type.strip_suffix('>').unwrap_or(inner_type),
85-
)
68+
let is_unit_type = is_unit_type(&ok_type);
69+
if is_unit_type {
70+
let mut text_range = ret_type.syntax().text_range();
71+
72+
if let Some(NodeOrToken::Token(token)) = ret_type.syntax().next_sibling_or_token() {
73+
if token.kind() == SyntaxKind::WHITESPACE {
74+
text_range = TextRange::new(text_range.start(), token.text_range().end());
75+
}
8676
}
77+
78+
builder.delete(text_range);
79+
} else {
80+
builder.replace(type_ref.syntax().text_range(), ok_type.syntax().text());
8781
}
8882

8983
for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +128,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134128
}
135129
}
136130

131+
// Tries to extract `T` from `Result<T, E>`.
132+
fn unwrap_result_type(ty: &ast::Type) -> Option<ast::Type> {
133+
let ast::Type::PathType(path_ty) = ty else { return None; };
134+
let path = path_ty.path()?;
135+
let segment = path.first_segment()?;
136+
let generic_arg_list = segment.generic_arg_list()?;
137+
let generic_args: Vec<_> = generic_arg_list.generic_args().collect();
138+
let ast::GenericArg::TypeArg(ok_type) = generic_args.first()? else { return None; };
139+
ok_type.ty()
140+
}
141+
142+
fn is_unit_type(ty: &ast::Type) -> bool {
143+
let ast::Type::TupleType(tuple) = ty else { return false };
144+
tuple.fields().next().is_none()
145+
}
146+
137147
#[cfg(test)]
138148
mod tests {
139149
use crate::tests::{check_assist, check_assist_not_applicable};
@@ -173,6 +183,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173183
r#"
174184
fn foo() {
175185
}
186+
"#,
187+
);
188+
189+
// Unformatted return type
190+
check_assist(
191+
unwrap_result_return_type,
192+
r#"
193+
//- minicore: result
194+
fn foo() -> Result<(), Box<dyn Error$0>>{
195+
Ok(())
196+
}
197+
"#,
198+
r#"
199+
fn foo() {
200+
}
176201
"#,
177202
);
178203
}
@@ -1014,6 +1039,54 @@ fn foo(the_field: u32) -> u32 {
10141039
}
10151040
the_field
10161041
}
1042+
"#,
1043+
);
1044+
}
1045+
1046+
#[test]
1047+
fn unwrap_result_return_type_nested_type() {
1048+
check_assist(
1049+
unwrap_result_return_type,
1050+
r#"
1051+
//- minicore: result, option
1052+
fn foo() -> Result<Option<i32$0>, ()> {
1053+
Ok(Some(42))
1054+
}
1055+
"#,
1056+
r#"
1057+
fn foo() -> Option<i32> {
1058+
Some(42)
1059+
}
1060+
"#,
1061+
);
1062+
1063+
check_assist(
1064+
unwrap_result_return_type,
1065+
r#"
1066+
//- minicore: result, option
1067+
fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1068+
Ok(None)
1069+
}
1070+
"#,
1071+
r#"
1072+
fn foo() -> Option<Result<i32, ()>> {
1073+
None
1074+
}
1075+
"#,
1076+
);
1077+
1078+
check_assist(
1079+
unwrap_result_return_type,
1080+
r#"
1081+
//- minicore: result, option, iterators
1082+
fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1083+
Ok(Some(42).into_iter())
1084+
}
1085+
"#,
1086+
r#"
1087+
fn foo() -> impl Iterator<Item = i32> {
1088+
Some(42).into_iter()
1089+
}
10171090
"#,
10181091
);
10191092
}

0 commit comments

Comments
 (0)