@@ -663,25 +663,28 @@ impl ast::LetStmt {
663
663
664
664
ted:: remove ( existing_ty. syntax ( ) ) ;
665
665
}
666
+
667
+ // Remove any trailing ws
668
+ if let Some ( last) = self . syntax ( ) . last_token ( ) . filter ( |it| it. kind ( ) == WHITESPACE )
669
+ {
670
+ last. detach ( ) ;
671
+ }
666
672
}
667
673
Some ( new_ty) => {
668
674
if self . colon_token ( ) . is_none ( ) {
669
- let mut to_insert: Vec < SyntaxElement > = vec ! [ ] ;
670
-
671
- let position = match self . pat ( ) {
672
- Some ( pat) => Position :: after ( pat. syntax ( ) ) ,
673
- None => {
674
- to_insert. push ( make:: tokens:: single_space ( ) . into ( ) ) ;
675
- Position :: after ( self . let_token ( ) . unwrap ( ) )
676
- }
677
- } ;
678
-
679
- to_insert. push ( make:: token ( T ! [ : ] ) . into ( ) ) ;
680
-
681
- ted:: insert_all_raw ( position, to_insert) ;
675
+ ted:: insert_raw (
676
+ Position :: after (
677
+ self . pat ( ) . expect ( "let stmt should have a pattern" ) . syntax ( ) ,
678
+ ) ,
679
+ make:: token ( T ! [ : ] ) ,
680
+ ) ;
682
681
}
683
682
684
- ted:: insert ( Position :: after ( self . colon_token ( ) . unwrap ( ) ) , new_ty. syntax ( ) ) ;
683
+ if let Some ( old_ty) = self . ty ( ) {
684
+ ted:: replace ( old_ty. syntax ( ) , new_ty. syntax ( ) ) ;
685
+ } else {
686
+ ted:: insert ( Position :: after ( self . colon_token ( ) . unwrap ( ) ) , new_ty. syntax ( ) ) ;
687
+ }
685
688
}
686
689
}
687
690
}
@@ -1022,6 +1025,37 @@ mod tests {
1022
1025
check ( "let a @ " , "let a" , None ) ;
1023
1026
}
1024
1027
1028
+ #[ test]
1029
+ fn test_let_stmt_set_ty ( ) {
1030
+ #[ track_caller]
1031
+ fn check ( before : & str , expected : & str , ty : Option < ast:: Type > ) {
1032
+ let ty = ty. map ( |it| it. clone_for_update ( ) ) ;
1033
+
1034
+ let let_stmt = ast_mut_from_text :: < ast:: LetStmt > ( & format ! ( "fn f() {{ {before} }}" ) ) ;
1035
+ let_stmt. set_ty ( ty) ;
1036
+
1037
+ let after = ast_mut_from_text :: < ast:: LetStmt > ( & format ! ( "fn f() {{ {expected} }}" ) ) ;
1038
+ assert_eq ! ( let_stmt. to_string( ) , after. to_string( ) , "{let_stmt:#?}\n !=\n {after:#?}" ) ;
1039
+ }
1040
+
1041
+ // adding
1042
+ check ( "let a;" , "let a: ();" , Some ( make:: ty_tuple ( [ ] ) ) ) ;
1043
+ // no semicolon due to it being eaten during error recovery
1044
+ check ( "let a:" , "let a: ()" , Some ( make:: ty_tuple ( [ ] ) ) ) ;
1045
+
1046
+ // replacing
1047
+ check ( "let a: u8;" , "let a: ();" , Some ( make:: ty_tuple ( [ ] ) ) ) ;
1048
+ check ( "let a: u8 = 3;" , "let a: () = 3;" , Some ( make:: ty_tuple ( [ ] ) ) ) ;
1049
+ check ( "let a: = 3;" , "let a: () = 3;" , Some ( make:: ty_tuple ( [ ] ) ) ) ;
1050
+
1051
+ // removing
1052
+ check ( "let a: u8;" , "let a;" , None ) ;
1053
+ check ( "let a:;" , "let a;" , None ) ;
1054
+
1055
+ check ( "let a: u8 = 3;" , "let a = 3;" , None ) ;
1056
+ check ( "let a: = 3;" , "let a = 3;" , None ) ;
1057
+ }
1058
+
1025
1059
#[ test]
1026
1060
fn add_variant_to_empty_enum ( ) {
1027
1061
let variant = make:: variant ( make:: name ( "Bar" ) , None ) . clone_for_update ( ) ;
0 commit comments