
Commit 5cd19b8

Optimize no-wrap niche discriminant cases

1 parent d956b30 commit 5cd19b8

File tree

6 files changed: +954 -110 lines changed


compiler/rustc_abi/src/lib.rs (+34 -1)
@@ -43,7 +43,7 @@ use std::fmt;
 #[cfg(feature = "nightly")]
 use std::iter::Step;
 use std::num::{NonZeroUsize, ParseIntError};
-use std::ops::{Add, AddAssign, Mul, RangeInclusive, Sub};
+use std::ops::{Add, AddAssign, Mul, RangeFull, RangeInclusive, Sub};
 use std::str::FromStr;
 
 use bitflags::bitflags;
@@ -1162,12 +1162,45 @@ impl WrappingRange {
     }
 
     /// Returns `true` if `size` completely fills the range.
+    ///
+    /// Note that this is *not* the same as `self == WrappingRange::full(size)`.
+    /// Niche calculations can produce full ranges which are not the canonical one;
+    /// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
     #[inline]
     fn is_full_for(&self, size: Size) -> bool {
         let max_value = size.unsigned_int_max();
         debug_assert!(self.start <= max_value && self.end <= max_value);
         self.start == (self.end.wrapping_add(1) & max_value)
     }
+
+    /// Checks whether this range is considered non-wrapping when the values are
+    /// interpreted as *unsigned* numbers of width `size`.
+    ///
+    /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
+    /// and `Err(..)` if the range is full so it depends how you think about it.
+    #[inline]
+    pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
+        if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
+    }
+
+    /// Checks whether this range is considered non-wrapping when the values are
+    /// interpreted as *signed* numbers of width `size`.
+    ///
+    /// This is heavily dependent on the `size`, as `100..=200` does wrap when
+    /// interpreted as `i8`, but doesn't when interpreted as `i16`.
+    ///
+    /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
+    /// and `Err(..)` if the range is full so it depends how you think about it.
+    #[inline]
+    pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
+        if self.is_full_for(size) {
+            Err(..)
+        } else {
+            let start: i128 = size.sign_extend(self.start);
+            let end: i128 = size.sign_extend(self.end);
+            Ok(start <= end)
+        }
+    }
 }
 
 impl fmt::Debug for WrappingRange {
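
As a rough illustration of what the two new helpers decide, here is a standalone sketch (plain u8/i8 arithmetic with hypothetical names, not the actual `rustc_abi` `WrappingRange`/`Size` types), including the `100..=200` example from the doc comment:

// Hypothetical stand-in for an 8-bit `WrappingRange`, purely to illustrate the
// unsigned/signed no-wraparound checks added above.
#[derive(Clone, Copy)]
struct Range8 {
    start: u8,
    end: u8,
}

impl Range8 {
    // Mirrors `is_full_for`: the range covers every 8-bit value.
    fn is_full(self) -> bool {
        self.start == self.end.wrapping_add(1)
    }

    // `None` if the range is full, otherwise whether `start..=end` avoids
    // wrap-around when the endpoints are read as unsigned.
    fn no_unsigned_wraparound(self) -> Option<bool> {
        if self.is_full() { None } else { Some(self.start <= self.end) }
    }

    // `None` if the range is full, otherwise whether `start..=end` avoids
    // wrap-around when the endpoints are sign-extended and compared as signed.
    fn no_signed_wraparound(self) -> Option<bool> {
        if self.is_full() { None } else { Some((self.start as i8) <= (self.end as i8)) }
    }
}

fn main() {
    // 100..=200 doesn't wrap as u8, but does as i8 (200 sign-extends to -56).
    let r = Range8 { start: 100, end: 200 };
    assert_eq!(r.no_unsigned_wraparound(), Some(true));
    assert_eq!(r.no_signed_wraparound(), Some(false));

    // A full-but-not-canonical range like `(..=0) | (1..)`: both checks decline.
    let full = Range8 { start: 1, end: 0 };
    assert_eq!(full.no_unsigned_wraparound(), None);
    assert_eq!(full.no_signed_wraparound(), None);
}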

compiler/rustc_codegen_ssa/src/mir/operand.rs (+181 -41)
@@ -3,7 +3,9 @@ use std::fmt;
 use arrayvec::ArrayVec;
 use either::Either;
 use rustc_abi as abi;
-use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
+use rustc_abi::{
+    Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, VariantIdx, Variants,
+};
 use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
 use rustc_middle::mir::{self, ConstValue};
 use rustc_middle::ty::Ty;
@@ -510,6 +512,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
                 );
 
                 let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+                let tag_range = tag_scalar.valid_range(&dl);
+                let tag_size = tag_scalar.size(&dl);
 
                 // We have a subrange `niche_start..=niche_end` inside `range`.
                 // If the value of the tag is inside this subrange, it's a
@@ -525,53 +529,189 @@
                 //     untagged_variant
                 // }
                 // However, we will likely be able to emit simpler code.
-                let (is_niche, tagged_discr, delta) = if relative_max == 0 {
-                    // Best case scenario: only one tagged variant. This will
-                    // likely become just a comparison and a jump.
-                    // The algorithm is:
-                    // is_niche = tag == niche_start
-                    // discr = if is_niche {
-                    //     niche_start
-                    // } else {
-                    //     untagged_variant
-                    // }
+
+                // First, the incredibly-common case of a two-variant enum (like
+                // `Option` or `Result`) where we only need one check.
+                if relative_max == 0 {
                     let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
-                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
-                    let tagged_discr =
-                        bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
-                    (is_niche, tagged_discr, 0)
-                } else {
-                    // The special cases don't apply, so we'll have to go with
-                    // the general algorithm.
-                    let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
-                    let cast_tag = bx.intcast(relative_discr, cast_to, false);
-                    let is_niche = bx.icmp(
-                        IntPredicate::IntULE,
-                        relative_discr,
-                        bx.cx().const_uint(tag_llty, relative_max as u64),
-                    );
-
-                    // Thanks to parameter attributes and load metadata, LLVM already knows
-                    // the general valid range of the tag. It's possible, though, for there
-                    // to be an impossible value *in the middle*, which those ranges don't
-                    // communicate, so it's worth an `assume` to let the optimizer know.
-                    if niche_variants.contains(&untagged_variant)
-                        && bx.cx().sess().opts.optimize != OptLevel::No
+                    let is_natural = bx.icmp(IntPredicate::IntNE, tag, niche_start);
+                    return if untagged_variant == VariantIdx::from_u32(1)
+                        && *niche_variants.start() == VariantIdx::from_u32(0)
                     {
-                        let impossible =
-                            u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
-                        let impossible = bx.cx().const_uint(tag_llty, impossible);
-                        let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
-                        bx.assume(ne);
+                        // The polarity of the comparison above is picked so we can
+                        // just extend for `Option<T>`, which has these variants.
+                        bx.zext(is_natural, cast_to)
+                    } else {
+                        let tagged_discr =
+                            bx.cx().const_uint(cast_to, u64::from(niche_variants.start().as_u32()));
+                        let untagged_discr =
+                            bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
+                        bx.select(is_natural, untagged_discr, tagged_discr)
+                    };
+                }
+
+                let niche_end =
+                    tag_size.truncate(u128::from(relative_max).wrapping_add(niche_start));
+
+                // Next, the layout algorithm prefers to put the niches at one end,
+                // so look for cases where we don't need to calculate a relative_tag
+                // at all and can just look at the original tag value directly.
+                // This also lets us move any possibly-wrapping addition to the end
+                // where it's easiest to get rid of in the normal uses: it's easy
+                // to optimize `COMPLICATED + 2 == 7` to `COMPLICATED == (7 - 2)`.
+                {
+                    // Work in whichever size is wider, because it's possible for
+                    // the untagged variant to be further away from the niches than
+                    // is possible to represent in the smaller type.
+                    let (wide_size, wide_ibty) = if cast_to_layout.size > tag_size {
+                        (cast_to_layout.size, cast_to)
+                    } else {
+                        (tag_size, tag_llty)
+                    };
+
+                    struct NoWrapData<V> {
+                        wide_tag: V,
+                        is_niche: V,
+                        needs_assume: bool,
+                        wide_niche_to_variant: u128,
+                        wide_niche_untagged: u128,
                     }
 
-                    (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
-                };
+                    let first_variant = u128::from(niche_variants.start().as_u32());
+                    let untagged_variant = u128::from(untagged_variant.as_u32());
+
+                    let opt_data = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
+                        let wide_tag = bx.zext(tag, wide_ibty);
+                        let extend = |x| x;
+                        let wide_niche_start = extend(niche_start);
+                        let wide_niche_end = extend(niche_end);
+                        debug_assert!(wide_niche_start <= wide_niche_end);
+                        let wide_first_variant = extend(first_variant);
+                        let wide_untagged_variant = extend(untagged_variant);
+                        let wide_niche_to_variant =
+                            wide_first_variant.wrapping_sub(wide_niche_start);
+                        let wide_niche_untagged = wide_size
+                            .truncate(wide_untagged_variant.wrapping_sub(wide_niche_to_variant));
+                        let (is_niche, needs_assume) = if tag_range.start == niche_start {
+                            let end = bx.cx().const_uint_big(tag_llty, niche_end);
+                            (
+                                bx.icmp(IntPredicate::IntULE, tag, end),
+                                wide_niche_untagged <= wide_niche_end,
+                            )
+                        } else if tag_range.end == niche_end {
+                            let start = bx.cx().const_uint_big(tag_llty, niche_start);
+                            (
+                                bx.icmp(IntPredicate::IntUGE, tag, start),
+                                wide_niche_untagged >= wide_niche_start,
+                            )
+                        } else {
+                            bug!()
+                        };
+                        Some(NoWrapData {
+                            wide_tag,
+                            is_niche,
+                            needs_assume,
+                            wide_niche_to_variant,
+                            wide_niche_untagged,
+                        })
+                    } else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
+                        let wide_tag = bx.sext(tag, wide_ibty);
+                        let extend = |x| tag_size.sign_extend(x);
+                        let wide_niche_start = extend(niche_start);
+                        let wide_niche_end = extend(niche_end);
+                        debug_assert!(wide_niche_start <= wide_niche_end);
+                        let wide_first_variant = extend(first_variant);
+                        let wide_untagged_variant = extend(untagged_variant);
+                        let wide_niche_to_variant =
+                            wide_first_variant.wrapping_sub(wide_niche_start);
+                        let wide_niche_untagged = wide_size.sign_extend(
+                            wide_untagged_variant
+                                .wrapping_sub(wide_niche_to_variant)
+                                .cast_unsigned(),
+                        );
+                        let (is_niche, needs_assume) = if tag_range.start == niche_start {
+                            let end = bx.cx().const_uint_big(tag_llty, niche_end);
+                            (
+                                bx.icmp(IntPredicate::IntSLE, tag, end),
+                                wide_niche_untagged <= wide_niche_end,
+                            )
+                        } else if tag_range.end == niche_end {
+                            let start = bx.cx().const_uint_big(tag_llty, niche_start);
+                            (
+                                bx.icmp(IntPredicate::IntSGE, tag, start),
+                                wide_niche_untagged >= wide_niche_start,
+                            )
+                        } else {
+                            bug!()
+                        };
+                        Some(NoWrapData {
+                            wide_tag,
+                            is_niche,
+                            needs_assume,
+                            wide_niche_to_variant: wide_niche_to_variant.cast_unsigned(),
+                            wide_niche_untagged: wide_niche_untagged.cast_unsigned(),
+                        })
+                    } else {
+                        None
+                    };
+                    if let Some(NoWrapData {
+                        wide_tag,
+                        is_niche,
+                        needs_assume,
+                        wide_niche_to_variant,
+                        wide_niche_untagged,
+                    }) = opt_data
+                    {
+                        let wide_niche_untagged =
+                            bx.cx().const_uint_big(wide_ibty, wide_niche_untagged);
+                        if needs_assume && bx.cx().sess().opts.optimize != OptLevel::No {
+                            let not_untagged =
+                                bx.icmp(IntPredicate::IntNE, wide_tag, wide_niche_untagged);
+                            bx.assume(not_untagged);
+                        }
+
+                        let wide_niche = bx.select(is_niche, wide_tag, wide_niche_untagged);
+                        let cast_niche = bx.trunc(wide_niche, cast_to);
+                        let discr = if wide_niche_to_variant == 0 {
+                            cast_niche
+                        } else {
+                            let niche_to_variant =
+                                bx.cx().const_uint_big(cast_to, wide_niche_to_variant);
+                            bx.add(cast_niche, niche_to_variant)
+                        };
+                        return discr;
+                    }
+                }
+
+                // Otherwise the special cases don't apply,
+                // so we'll have to go with the general algorithm.
+                let relative_tag = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
+                let relative_discr = bx.intcast(relative_tag, cast_to, false);
+                let is_niche = bx.icmp(
+                    IntPredicate::IntULE,
+                    relative_tag,
+                    bx.cx().const_uint(tag_llty, u64::from(relative_max)),
+                );
+
+                // Thanks to parameter attributes and load metadata, LLVM already knows
+                // the general valid range of the tag. It's possible, though, for there
+                // to be an impossible value *in the middle*, which those ranges don't
+                // communicate, so it's worth an `assume` to let the optimizer know.
+                if niche_variants.contains(&untagged_variant)
+                    && bx.cx().sess().opts.optimize != OptLevel::No
+                {
+                    let impossible =
+                        u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
+                    let impossible = bx.cx().const_uint(tag_llty, impossible);
+                    let ne = bx.icmp(IntPredicate::IntNE, relative_tag, impossible);
+                    bx.assume(ne);
+                }
 
+                let delta = niche_variants.start().as_u32();
                 let tagged_discr = if delta == 0 {
-                    tagged_discr
+                    relative_discr
                 } else {
-                    bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
+                    bx.add(relative_discr, bx.cx().const_uint(cast_to, u64::from(delta)))
                 };
 
                 let discr = bx.select(