Skip to content

Back InitMask by IntervalSet #94450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,14 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> Memory<'mir, 'tcx, M> {
// Zero-sized *destination*.
return Ok(());
};
let src_all_uninit = src_alloc.no_bytes_init(src_range);
// FIXME: This is potentially bad for performance as the init mask could
// be large, but is currently necessary to work around needing to have
// both the init mask for the src_alloc (shared ref) and the dst_alloc
// (unique ref) available simultaneously. Those are accessed through
// `self.get_raw{,_mut}` and we can't currently explain to rustc that
// there's no invalidation of the two references.
let src_init_mask = src_alloc.init_mask().clone();

// This checks relocation edges on the src, which needs to happen before
// `prepare_relocation_copy`.
Expand All @@ -1047,8 +1055,6 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> Memory<'mir, 'tcx, M> {
// since we don't want to keep any relocations at the target.
let relocations =
src_alloc.prepare_relocation_copy(self, src_range, dest_offset, num_copies);
// Prepare a copy of the initialization mask.
let compressed = src_alloc.compress_uninit_range(src_range);

// Destination alloc preparations and access hooks.
let (dest_alloc, extra) = self.get_raw_mut(dest_alloc_id)?;
Expand All @@ -1059,7 +1065,7 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> Memory<'mir, 'tcx, M> {
.map_err(|e| e.to_interp_error(dest_alloc_id))?
.as_mut_ptr();

if compressed.no_bytes_init() {
if src_all_uninit {
// Fast path: If all bytes are `uninit` then there is nothing to copy. The target range
// is marked as uninitialized but we otherwise omit changing the byte representation which may
// be arbitrary for uninitialized bytes.
Expand Down Expand Up @@ -1106,8 +1112,9 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> Memory<'mir, 'tcx, M> {
}

// now fill in all the "init" data
dest_alloc.mark_compressed_init_range(
&compressed,
dest_alloc.mark_init_range_repeated(
src_init_mask,
src_range,
alloc_range(dest_offset, size), // just a single copy (i.e., not full `dest_range`)
num_copies,
);
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_data_structures/src/stable_hasher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::sip128::SipHasher128;
use rustc_index::bit_set;
use rustc_index::interval;
use rustc_index::vec;
use smallvec::SmallVec;
use std::hash::{BuildHasher, Hash, Hasher};
Expand Down Expand Up @@ -510,6 +511,12 @@ impl<I: vec::Idx, CTX> HashStable<CTX> for bit_set::BitSet<I> {
}
}

// Stable hashing for `IntervalSet` forwards to the set's ordinary `Hash`
// implementation; the hashing context is not consulted.
impl<I: vec::Idx, CTX> HashStable<CTX> for interval::IntervalSet<I> {
    fn hash_stable(&self, _ctx: &mut CTX, hasher: &mut StableHasher) {
        <interval::IntervalSet<I> as ::std::hash::Hash>::hash(self, hasher);
    }
}

impl<R: vec::Idx, C: vec::Idx, CTX> HashStable<CTX> for bit_set::BitMatrix<R, C> {
fn hash_stable(&self, _ctx: &mut CTX, hasher: &mut StableHasher) {
::std::hash::Hash::hash(self, hasher);
Expand Down
128 changes: 110 additions & 18 deletions compiler/rustc_index/src/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,73 @@ use std::ops::RangeBounds;

use crate::vec::Idx;
use crate::vec::IndexVec;
use rustc_macros::{Decodable, Encodable};
use smallvec::SmallVec;

#[cfg(test)]
mod tests;

/// Stores a set of intervals on the indices.
#[derive(Debug, Clone)]
#[derive(Clone, PartialEq, Eq, Hash, Encodable, Decodable)]
pub struct IntervalSet<I> {
// Start, end
map: SmallVec<[(u32, u32); 4]>,
map: SmallVec<[(I, I); 4]>,
domain: usize,
_data: PhantomData<I>,
}

// Formats the set as `IntervalSet { domain_size: <n>, set: [<range>, ...] }`,
// where each entry of `set` is one of the half-open ranges yielded by
// `iter_intervals`.
impl<I: Ord + Idx + Step> std::fmt::Debug for IntervalSet<I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Adapter whose `Debug` output is the list of intervals of the wrapped set.
        struct AsList<'a, I>(&'a IntervalSet<I>);

        impl<'a, I: Idx + Ord + Step> std::fmt::Debug for AsList<'a, I> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_list().entries(self.0.iter_intervals()).finish()
            }
        }

        // `finish` is what writes the closing brace and reports any error hit
        // while formatting the fields; returning a bare `Ok(())` instead would
        // leave the output unterminated and swallow such errors.
        f.debug_struct("IntervalSet")
            .field("domain_size", &self.domain)
            .field("set", &AsList(self))
            .finish()
    }
}

#[inline]
fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> T {
match range.start_bound() {
Bound::Included(start) => start.index() as u32,
Bound::Excluded(start) => start.index() as u32 + 1,
Bound::Unbounded => 0,
Bound::Included(start) => *start,
Bound::Excluded(start) => T::new(start.index() + 1),
Bound::Unbounded => T::new(0),
}
}

#[inline]
fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<T> {
let end = match range.end_bound() {
Bound::Included(end) => end.index() as u32,
Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
Bound::Unbounded => domain.checked_sub(1)? as u32,
Bound::Included(end) => *end,
Bound::Excluded(end) => T::new(end.index().checked_sub(1)?),
Bound::Unbounded => T::new(domain.checked_sub(1)?),
};
Some(end)
}

impl<I: Idx> IntervalSet<I> {
impl<I: Ord + Idx> IntervalSet<I> {
pub fn new(domain: usize) -> IntervalSet<I> {
IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
}

/// Grows the recorded domain size to `min_domain_size` if it is currently
/// smaller; a domain that is already large enough is left untouched.
pub fn ensure(&mut self, min_domain_size: usize) {
    self.domain = std::cmp::max(self.domain, min_domain_size);
}

/// Returns the domain size currently recorded for this set.
pub fn domain_size(&self) -> usize {
self.domain
}

/// Removes every stored interval, leaving the set empty. The domain size is
/// not changed.
pub fn clear(&mut self) {
self.map.clear();
}
Expand All @@ -59,14 +88,18 @@ impl<I: Idx> IntervalSet<I> {
where
I: Step,
{
self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
self.map.iter().map(|&(start, end)| start..I::new(end.index() + 1))
}

/// Inserts the single element `point` by delegating to `insert_range` with
/// the one-element range `point..=point`.
///
/// Returns true if we increased the number of elements present.
pub fn insert(&mut self, point: I) -> bool {
self.insert_range(point..=point)
}

/// Removes the single element `point` by delegating to `remove_range` with
/// the one-element range `point..=point`.
pub fn remove(&mut self, point: I) {
self.remove_range(point..=point);
}

/// Returns true if we increased the number of elements present.
pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
let start = inclusive_start(range.clone());
Expand All @@ -84,10 +117,10 @@ impl<I: Idx> IntervalSet<I> {
// if r.0 == end + 1, then we're actually adjacent, so we want to
// continue to the next range. We're looking here for the first
// range which starts *non-adjacently* to our end.
let next = self.map.partition_point(|r| r.0 <= end + 1);
let next = self.map.partition_point(|r| r.0.index() <= end.index() + 1);
if let Some(last) = next.checked_sub(1) {
let (prev_start, prev_end) = &mut self.map[last];
if *prev_end + 1 >= start {
if prev_end.index() + 1 >= start.index() {
// If the start for the inserted range is adjacent to the
// end of the previous, we can extend the previous range.
if start < *prev_start {
Expand Down Expand Up @@ -134,8 +167,29 @@ impl<I: Idx> IntervalSet<I> {
}
}

/// Removes every element of `range` from the set.
///
/// The range is first inserted so that it is covered by a single stored
/// interval; that interval is then taken out of the map and the parts of it
/// lying before and after `range` are inserted again.
pub fn remove_range(&mut self, range: impl RangeBounds<I> + Clone) {
let start = inclusive_start(range.clone());
let Some(end) = inclusive_end(self.domain, range.clone()) else {
// empty range
return;
};
if start > end {
return;
}
// We insert the range, so that any previous gaps are merged into just one large
// range, which we can then split in the next step (either inserting a
// smaller range after or not).
self.insert_range(range);
// Find the range we just inserted.
// The `unwrap` relies on the insertion above: some stored interval must now
// start at or before `end`.
let idx = self.map.partition_point(|r| r.0 <= end).checked_sub(1).unwrap();
let (prev_start, prev_end) = self.map.remove(idx);
// The range we're looking at contains the range we're removing completely.
assert!(prev_start <= start && end <= prev_end);
// Put back the part before the removed range and the part after it.
// NOTE(review): either piece may be empty (`prev_start == start` or
// `end == prev_end`); presumably `insert_range` treats an empty range as a
// no-op — confirm against its full body.
self.insert_range(prev_start..start);
self.insert_range((Bound::Excluded(end), Bound::Included(prev_end)));
}

pub fn contains(&self, needle: I) -> bool {
let needle = needle.index() as u32;
let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
// All ranges in the map start after `needle`
return false;
Expand All @@ -157,6 +211,44 @@ impl<I: Idx> IntervalSet<I> {
self.map.is_empty()
}

/// Returns the smallest element of the set that lies inside `range`, or
/// `None` when `range` is empty or contains no element of the set.
pub fn first_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
    let lo = inclusive_start(range.clone());
    // No inclusive upper bound means the requested range is empty.
    let hi = inclusive_end(self.domain, range)?;
    if lo > hi {
        return None;
    }
    // First stored interval that ends at or after `lo`; if there is none,
    // nothing in the set can fall inside the range.
    let candidate = self.map.partition_point(|r| r.1 < lo);
    let &(first, _) = self.map.get(candidate)?;
    // The interval is only relevant if it begins no later than `hi`; the
    // answer is then its start, clamped up to `lo`.
    (first <= hi).then(|| std::cmp::max(first, lo))
}

/// Returns the smallest element inside `range` that is **not** in the set,
/// or `None` when `range` is empty or no such element is found.
pub fn first_gap_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
    let lo = inclusive_start(range.clone());
    // No inclusive upper bound means the requested range is empty.
    let hi = inclusive_end(self.domain, range)?;
    if lo > hi {
        return None;
    }
    // First stored interval that ends at or after `lo`.
    let idx = self.map.partition_point(|r| r.1 < lo);
    match self.map.get(idx) {
        // Every stored interval ends before `lo`, so `lo` itself is a gap.
        None => Some(lo),
        // The interval begins after `lo`, so `lo` is not covered.
        Some(&(first, _)) if lo < first => Some(lo),
        // `lo` is covered: the only candidate is the element right after the
        // covering interval, provided it is still inside both the domain and
        // the requested range.
        Some(&(_, last)) => {
            let after = last.index() + 1;
            (after < self.domain && after <= hi.index()).then(|| I::new(after))
        }
    }
}

/// Returns the maximum (last) element present in the set from `range`.
pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
let start = inclusive_start(range.clone());
Expand All @@ -172,12 +264,12 @@ impl<I: Idx> IntervalSet<I> {
return None;
};
let (_, prev_end) = &self.map[last];
if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
if start <= *prev_end { Some(std::cmp::min(*prev_end, end)) } else { None }
}

pub fn insert_all(&mut self) {
self.clear();
self.map.push((0, self.domain.try_into().unwrap()));
self.map.push((I::new(0), I::new(self.domain)));
}

pub fn union(&mut self, other: &IntervalSet<I>) -> bool
Expand Down Expand Up @@ -208,7 +300,7 @@ where
column_size: usize,
}

impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
impl<R: Idx, C: Ord + Step + Idx> SparseIntervalMatrix<R, C> {
pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
SparseIntervalMatrix { rows: IndexVec::new(), column_size }
}
Expand Down
Loading