|
1 |
| -use std::iter; |
| 1 | +use std::{iter, usize}; |
2 | 2 |
|
| 3 | +use rustc_const_eval::const_eval::mk_eval_cx_for_const_val; |
| 4 | +use rustc_index::bit_set::BitSet; |
3 | 5 | use rustc_index::IndexSlice;
|
4 | 6 | use rustc_middle::mir::patch::MirPatch;
|
5 | 7 | use rustc_middle::mir::*;
|
| 8 | +use rustc_middle::ty; |
6 | 9 | use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
|
| 10 | +use rustc_middle::ty::util::Discr; |
7 | 11 | use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
|
| 12 | +use rustc_mir_dataflow::impls::{borrowed_locals, MaybeTransitiveLiveLocals}; |
| 13 | +use rustc_mir_dataflow::Analysis; |
8 | 14 | use rustc_target::abi::Integer;
|
9 | 15 | use rustc_type_ir::TyKind::*;
|
10 | 16 |
|
@@ -48,6 +54,10 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
|
48 | 54 | should_cleanup = true;
|
49 | 55 | continue;
|
50 | 56 | }
|
| 57 | + if simplify_to_copy(tcx, body, bb_idx, param_env).is_some() { |
| 58 | + should_cleanup = true; |
| 59 | + continue; |
| 60 | + } |
51 | 61 | }
|
52 | 62 |
|
53 | 63 | if should_cleanup {
|
@@ -519,3 +529,212 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
|
519 | 529 | }
|
520 | 530 | }
|
521 | 531 | }
|
| 532 | + |
| 533 | +/// This is primarily used to merge these copy statements that simplified the canonical enum clone method by GVN. |
| 534 | +/// The GVN simplified |
| 535 | +/// ```ignore (syntax-highlighting-only) |
| 536 | +/// match a { |
| 537 | +/// Foo::A(x) => Foo::A(*x), |
| 538 | +/// Foo::B => Foo::B |
| 539 | +/// } |
| 540 | +/// ``` |
| 541 | +/// to |
| 542 | +/// ```ignore (syntax-highlighting-only) |
| 543 | +/// match a { |
| 544 | +/// Foo::A(_x) => a, // copy a |
| 545 | +/// Foo::B => Foo::B |
| 546 | +/// } |
| 547 | +/// ``` |
| 548 | +/// This function will simplify into a copy statement. |
| 549 | +fn simplify_to_copy<'tcx>( |
| 550 | + tcx: TyCtxt<'tcx>, |
| 551 | + body: &mut Body<'tcx>, |
| 552 | + switch_bb_idx: BasicBlock, |
| 553 | + param_env: ParamEnv<'tcx>, |
| 554 | +) -> Option<()> { |
| 555 | + // To save compile time, only consider the first BB has a switch terminator. |
| 556 | + if switch_bb_idx != START_BLOCK { |
| 557 | + return None; |
| 558 | + } |
| 559 | + let bbs = &body.basic_blocks; |
| 560 | + // Check if the copy source matches the following pattern. |
| 561 | + // _2 = discriminant(*_1); // "*_1" is the expected the copy source. |
| 562 | + // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; |
| 563 | + let &Statement { |
| 564 | + kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(expected_src_place))), |
| 565 | + .. |
| 566 | + } = bbs[switch_bb_idx].statements.last()? |
| 567 | + else { |
| 568 | + return None; |
| 569 | + }; |
| 570 | + let expected_src_ty = expected_src_place.ty(body.local_decls(), tcx); |
| 571 | + if !expected_src_ty.ty.is_enum() || expected_src_ty.variant_index.is_some() { |
| 572 | + return None; |
| 573 | + } |
| 574 | + // To save compile time, only consider the copy source is assigned to the return place. |
| 575 | + let expected_dest_place = Place::return_place(); |
| 576 | + let expected_dest_ty = expected_dest_place.ty(body.local_decls(), tcx); |
| 577 | + if expected_dest_ty.ty != expected_src_ty.ty || expected_dest_ty.variant_index.is_some() { |
| 578 | + return None; |
| 579 | + } |
| 580 | + let targets = match bbs[switch_bb_idx].terminator().kind { |
| 581 | + TerminatorKind::SwitchInt { ref discr, ref targets, .. } |
| 582 | + if discr.place() == Some(discr_place) => |
| 583 | + { |
| 584 | + targets |
| 585 | + } |
| 586 | + _ => return None, |
| 587 | + }; |
| 588 | + // We require that the possible target blocks all be distinct. |
| 589 | + if !targets.is_distinct() { |
| 590 | + return None; |
| 591 | + } |
| 592 | + if !bbs[targets.otherwise()].is_empty_unreachable() { |
| 593 | + return None; |
| 594 | + } |
| 595 | + // Check that destinations are identical, and if not, then don't optimize this block. |
| 596 | + let mut target_iter = targets.iter(); |
| 597 | + let first_terminator_kind = &bbs[target_iter.next().unwrap().1].terminator().kind; |
| 598 | + if !target_iter |
| 599 | + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) |
| 600 | + { |
| 601 | + return None; |
| 602 | + } |
| 603 | + |
| 604 | + let borrowed_locals = borrowed_locals(body); |
| 605 | + let mut live = None; |
| 606 | + |
| 607 | + for (index, target_bb) in targets.iter() { |
| 608 | + let stmts = &bbs[target_bb].statements; |
| 609 | + if stmts.is_empty() { |
| 610 | + return None; |
| 611 | + } |
| 612 | + if let [Statement { kind: StatementKind::Assign(box (place, rvalue)), .. }] = |
| 613 | + bbs[target_bb].statements.as_slice() |
| 614 | + { |
| 615 | + let dest_ty = place.ty(body.local_decls(), tcx); |
| 616 | + if dest_ty.ty != expected_src_ty.ty || dest_ty.variant_index.is_some() { |
| 617 | + return None; |
| 618 | + } |
| 619 | + let ty::Adt(def, _) = dest_ty.ty.kind() else { |
| 620 | + return None; |
| 621 | + }; |
| 622 | + if expected_dest_place != *place { |
| 623 | + return None; |
| 624 | + } |
| 625 | + match rvalue { |
| 626 | + // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`. |
| 627 | + Rvalue::Use(Operand::Constant(box constant)) |
| 628 | + if let Const::Val(const_, ty) = constant.const_ => |
| 629 | + { |
| 630 | + let (ecx, op) = |
| 631 | + mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?; |
| 632 | + let variant = ecx.read_discriminant(&op).ok()?; |
| 633 | + if !def.variants()[variant].fields.is_empty() { |
| 634 | + return None; |
| 635 | + } |
| 636 | + let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?; |
| 637 | + if val != index { |
| 638 | + return None; |
| 639 | + } |
| 640 | + } |
| 641 | + Rvalue::Use(Operand::Copy(src_place)) if *src_place == expected_src_place => {} |
| 642 | + // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`. |
| 643 | + Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields) |
| 644 | + if fields.is_empty() |
| 645 | + && let Some(Discr { val, .. }) = |
| 646 | + expected_src_ty.ty.discriminant_for_variant(tcx, *variant_index) |
| 647 | + && val == index => {} |
| 648 | + _ => return None, |
| 649 | + } |
| 650 | + } else { |
| 651 | + // If the BB contains more than one statement, we have to check if these statements can be ignored. |
| 652 | + let mut lived_stmts: BitSet<usize> = |
| 653 | + BitSet::new_filled(bbs[target_bb].statements.len()); |
| 654 | + let mut expected_copy_stmt = None; |
| 655 | + for (statement_index, statement) in bbs[target_bb].statements.iter().enumerate().rev() { |
| 656 | + let loc = Location { block: target_bb, statement_index }; |
| 657 | + if let StatementKind::Assign(assign) = &statement.kind { |
| 658 | + if !assign.1.is_safe_to_remove() { |
| 659 | + return None; |
| 660 | + } |
| 661 | + } |
| 662 | + match &statement.kind { |
| 663 | + StatementKind::Assign(box (place, _)) |
| 664 | + | StatementKind::SetDiscriminant { place: box place, .. } |
| 665 | + | StatementKind::Deinit(box place) => { |
| 666 | + if place.is_indirect() || borrowed_locals.contains(place.local) { |
| 667 | + return None; |
| 668 | + } |
| 669 | + let live = live.get_or_insert_with(|| { |
| 670 | + MaybeTransitiveLiveLocals::new(&borrowed_locals) |
| 671 | + .into_engine(tcx, body) |
| 672 | + .iterate_to_fixpoint() |
| 673 | + .into_results_cursor(body) |
| 674 | + }); |
| 675 | + live.seek_before_primary_effect(loc); |
| 676 | + if !live.get().contains(place.local) { |
| 677 | + lived_stmts.remove(statement_index); |
| 678 | + } else if let StatementKind::Assign(box ( |
| 679 | + _, |
| 680 | + Rvalue::Use(Operand::Copy(src_place)), |
| 681 | + )) = statement.kind |
| 682 | + && expected_copy_stmt.is_none() |
| 683 | + && expected_src_place == src_place |
| 684 | + && expected_dest_place == *place |
| 685 | + { |
| 686 | + // There is only one statement that cannot be ignored that can be used as an expected copy statement. |
| 687 | + expected_copy_stmt = Some(statement_index); |
| 688 | + } else { |
| 689 | + return None; |
| 690 | + } |
| 691 | + } |
| 692 | + StatementKind::StorageLive(_) |
| 693 | + | StatementKind::StorageDead(_) |
| 694 | + | StatementKind::Nop => (), |
| 695 | + |
| 696 | + StatementKind::Retag(_, _) |
| 697 | + | StatementKind::Coverage(_) |
| 698 | + | StatementKind::Intrinsic(_) |
| 699 | + | StatementKind::ConstEvalCounter |
| 700 | + | StatementKind::PlaceMention(_) |
| 701 | + | StatementKind::FakeRead(_) |
| 702 | + | StatementKind::AscribeUserType(_, _) => { |
| 703 | + return None; |
| 704 | + } |
| 705 | + } |
| 706 | + } |
| 707 | + let expected_copy_stmt = expected_copy_stmt?; |
| 708 | + // We can ignore the paired StorageLive and StorageDead. |
| 709 | + let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len()); |
| 710 | + for stmt_index in lived_stmts.iter() { |
| 711 | + let statement = &bbs[target_bb].statements[stmt_index]; |
| 712 | + match &statement.kind { |
| 713 | + StatementKind::Assign(_) if expected_copy_stmt == stmt_index => {} |
| 714 | + StatementKind::StorageLive(local) |
| 715 | + if *local != expected_dest_place.local |
| 716 | + && storage_live_locals.insert(*local) => {} |
| 717 | + StatementKind::StorageDead(local) |
| 718 | + if *local != expected_dest_place.local |
| 719 | + && storage_live_locals.remove(*local) => {} |
| 720 | + StatementKind::Nop => {} |
| 721 | + _ => return None, |
| 722 | + } |
| 723 | + } |
| 724 | + if !storage_live_locals.is_empty() { |
| 725 | + return None; |
| 726 | + } |
| 727 | + } |
| 728 | + } |
| 729 | + let statement_index = bbs[switch_bb_idx].statements.len(); |
| 730 | + let parent_end = Location { block: switch_bb_idx, statement_index }; |
| 731 | + let mut patch = MirPatch::new(body); |
| 732 | + patch.add_assign( |
| 733 | + parent_end, |
| 734 | + expected_dest_place, |
| 735 | + Rvalue::Use(Operand::Copy(expected_src_place)), |
| 736 | + ); |
| 737 | + patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone()); |
| 738 | + patch.apply(body); |
| 739 | + Some(()) |
| 740 | +} |
0 commit comments