Skip to content

Commit 0a836e8

Browse files
authored
Merge pull request #238 from NiklasJonsson/get_many_mut
Implement get_disjoint_mut (previously get_many_mut)
2 parents d10de30 + 434d7ac commit 0a836e8

File tree

4 files changed

+298
-1
lines changed

4 files changed

+298
-1
lines changed

Diff for: src/lib.rs

+30
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,33 @@ impl core::fmt::Display for TryReserveError {
269269
#[cfg(feature = "std")]
270270
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
271271
impl std::error::Error for TryReserveError {}
272+
273+
// NOTE: This is copied from the slice module in the std lib.
274+
/// The error type returned by [`get_disjoint_indices_mut`][`IndexMap::get_disjoint_indices_mut`].
275+
///
276+
/// It indicates one of two possible errors:
277+
/// - An index is out-of-bounds.
278+
/// - The same index appeared multiple times in the array.
279+
// (or different but overlapping indices when ranges are provided)
280+
#[derive(Debug, Clone, PartialEq, Eq)]
281+
pub enum GetDisjointMutError {
282+
/// An index provided was out-of-bounds for the slice.
283+
IndexOutOfBounds,
284+
/// Two indices provided were overlapping.
285+
OverlappingIndices,
286+
}
287+
288+
impl core::fmt::Display for GetDisjointMutError {
289+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
290+
let msg = match self {
291+
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
292+
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
293+
};
294+
295+
core::fmt::Display::fmt(msg, f)
296+
}
297+
}
298+
299+
#[cfg(feature = "std")]
300+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
301+
impl std::error::Error for GetDisjointMutError {}

Diff for: src/map.rs

+44-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use std::collections::hash_map::RandomState;
3838

3939
use self::core::IndexMapCore;
4040
use crate::util::{third, try_simplify_range};
41-
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
41+
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};
4242

4343
/// A hash table where the iteration order of the key-value pairs is independent
4444
/// of the hash values of the keys.
@@ -790,6 +790,32 @@ where
790790
}
791791
}
792792

793+
/// Return the values for `N` keys. If any key is duplicated, this function will panic.
794+
///
795+
/// # Examples
796+
///
797+
/// ```
798+
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
799+
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
800+
/// ```
801+
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
802+
where
803+
Q: ?Sized + Hash + Equivalent<K>,
804+
{
805+
let indices = keys.map(|key| self.get_index_of(key));
806+
match self.as_mut_slice().get_disjoint_opt_mut(indices) {
807+
Err(GetDisjointMutError::IndexOutOfBounds) => {
808+
unreachable!(
809+
"Internal error: indices should never be OOB as we got them from get_index_of"
810+
);
811+
}
812+
Err(GetDisjointMutError::OverlappingIndices) => {
813+
panic!("duplicate keys found");
814+
}
815+
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
816+
}
817+
}
818+
793819
/// Remove the key-value pair equivalent to `key` and return
794820
/// its value.
795821
///
@@ -1196,6 +1222,23 @@ impl<K, V, S> IndexMap<K, V, S> {
11961222
Some(IndexedEntry::new(&mut self.core, index))
11971223
}
11981224

1225+
/// Get an array of `N` key-value pairs by `N` indices
1226+
///
1227+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
1228+
///
1229+
/// # Examples
1230+
///
1231+
/// ```
1232+
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
1233+
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
1234+
/// ```
1235+
pub fn get_disjoint_indices_mut<const N: usize>(
1236+
&mut self,
1237+
indices: [usize; N],
1238+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
1239+
self.as_mut_slice().get_disjoint_mut(indices)
1240+
}
1241+
11991242
/// Returns a slice of key-value pairs in the given range of indices.
12001243
///
12011244
/// Valid indices are `0 <= index < self.len()`.

Diff for: src/map/slice.rs

+46
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use super::{
33
ValuesMut,
44
};
55
use crate::util::{slice_eq, try_simplify_range};
6+
use crate::GetDisjointMutError;
67

78
use alloc::boxed::Box;
89
use alloc::vec::Vec;
@@ -270,6 +271,51 @@ impl<K, V> Slice<K, V> {
270271
self.entries
271272
.partition_point(move |a| pred(&a.key, &a.value))
272273
}
274+
275+
/// Get an array of `N` key-value pairs by `N` indices
276+
///
277+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
278+
pub fn get_disjoint_mut<const N: usize>(
279+
&mut self,
280+
indices: [usize; N],
281+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
282+
let indices = indices.map(Some);
283+
let key_values = self.get_disjoint_opt_mut(indices)?;
284+
Ok(key_values.map(Option::unwrap))
285+
}
286+
287+
#[allow(unsafe_code)]
288+
pub(crate) fn get_disjoint_opt_mut<const N: usize>(
289+
&mut self,
290+
indices: [Option<usize>; N],
291+
) -> Result<[Option<(&K, &mut V)>; N], GetDisjointMutError> {
292+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
293+
let len = self.len();
294+
for i in 0..N {
295+
if let Some(idx) = indices[i] {
296+
if idx >= len {
297+
return Err(GetDisjointMutError::IndexOutOfBounds);
298+
} else if indices[..i].contains(&Some(idx)) {
299+
return Err(GetDisjointMutError::OverlappingIndices);
300+
}
301+
}
302+
}
303+
304+
let entries_ptr = self.entries.as_mut_ptr();
305+
let out = indices.map(|idx_opt| {
306+
match idx_opt {
307+
Some(idx) => {
308+
// SAFETY: The base pointer is valid as it comes from a slice and the reference is always
309+
// in-bounds & unique as we've already checked the indices above.
310+
let kv = unsafe { (*(entries_ptr.add(idx))).ref_mut() };
311+
Some(kv)
312+
}
313+
None => None,
314+
}
315+
});
316+
317+
Ok(out)
318+
}
273319
}
274320

275321
impl<'a, K, V> IntoIterator for &'a Slice<K, V> {

Diff for: src/map/tests.rs

+178
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,181 @@ move_index_oob!(test_move_index_out_of_bounds_0_10, 0, 10);
828828
move_index_oob!(test_move_index_out_of_bounds_0_max, 0, usize::MAX);
829829
move_index_oob!(test_move_index_out_of_bounds_10_0, 10, 0);
830830
move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0);
831+
832+
#[test]
833+
fn disjoint_mut_empty_map() {
834+
let mut map: IndexMap<u32, u32> = IndexMap::default();
835+
assert_eq!(
836+
map.get_disjoint_mut([&0, &1, &2, &3]),
837+
[None, None, None, None]
838+
);
839+
}
840+
841+
#[test]
842+
fn disjoint_mut_empty_param() {
843+
let mut map: IndexMap<u32, u32> = IndexMap::default();
844+
map.insert(1, 10);
845+
assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []);
846+
}
847+
848+
#[test]
849+
fn disjoint_mut_single_fail() {
850+
let mut map: IndexMap<u32, u32> = IndexMap::default();
851+
map.insert(1, 10);
852+
assert_eq!(map.get_disjoint_mut([&0]), [None]);
853+
}
854+
855+
#[test]
856+
fn disjoint_mut_single_success() {
857+
let mut map: IndexMap<u32, u32> = IndexMap::default();
858+
map.insert(1, 10);
859+
assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]);
860+
}
861+
862+
#[test]
863+
fn disjoint_mut_multi_success() {
864+
let mut map: IndexMap<u32, u32> = IndexMap::default();
865+
map.insert(1, 100);
866+
map.insert(2, 200);
867+
map.insert(3, 300);
868+
map.insert(4, 400);
869+
assert_eq!(
870+
map.get_disjoint_mut([&1, &2]),
871+
[Some(&mut 100), Some(&mut 200)]
872+
);
873+
assert_eq!(
874+
map.get_disjoint_mut([&1, &3]),
875+
[Some(&mut 100), Some(&mut 300)]
876+
);
877+
assert_eq!(
878+
map.get_disjoint_mut([&3, &1, &4, &2]),
879+
[
880+
Some(&mut 300),
881+
Some(&mut 100),
882+
Some(&mut 400),
883+
Some(&mut 200)
884+
]
885+
);
886+
}
887+
888+
#[test]
889+
fn disjoint_mut_multi_success_unsized_key() {
890+
let mut map: IndexMap<&'static str, u32> = IndexMap::default();
891+
map.insert("1", 100);
892+
map.insert("2", 200);
893+
map.insert("3", 300);
894+
map.insert("4", 400);
895+
896+
assert_eq!(
897+
map.get_disjoint_mut(["1", "2"]),
898+
[Some(&mut 100), Some(&mut 200)]
899+
);
900+
assert_eq!(
901+
map.get_disjoint_mut(["1", "3"]),
902+
[Some(&mut 100), Some(&mut 300)]
903+
);
904+
assert_eq!(
905+
map.get_disjoint_mut(["3", "1", "4", "2"]),
906+
[
907+
Some(&mut 300),
908+
Some(&mut 100),
909+
Some(&mut 400),
910+
Some(&mut 200)
911+
]
912+
);
913+
}
914+
915+
#[test]
916+
fn disjoint_mut_multi_success_borrow_key() {
917+
let mut map: IndexMap<String, u32> = IndexMap::default();
918+
map.insert("1".into(), 100);
919+
map.insert("2".into(), 200);
920+
map.insert("3".into(), 300);
921+
map.insert("4".into(), 400);
922+
923+
assert_eq!(
924+
map.get_disjoint_mut(["1", "2"]),
925+
[Some(&mut 100), Some(&mut 200)]
926+
);
927+
assert_eq!(
928+
map.get_disjoint_mut(["1", "3"]),
929+
[Some(&mut 100), Some(&mut 300)]
930+
);
931+
assert_eq!(
932+
map.get_disjoint_mut(["3", "1", "4", "2"]),
933+
[
934+
Some(&mut 300),
935+
Some(&mut 100),
936+
Some(&mut 400),
937+
Some(&mut 200)
938+
]
939+
);
940+
}
941+
942+
#[test]
943+
fn disjoint_mut_multi_fail_missing() {
944+
let mut map: IndexMap<u32, u32> = IndexMap::default();
945+
map.insert(1, 100);
946+
map.insert(2, 200);
947+
map.insert(3, 300);
948+
map.insert(4, 400);
949+
950+
assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]);
951+
assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]);
952+
assert_eq!(
953+
map.get_disjoint_mut([&1, &5, &4]),
954+
[Some(&mut 100), None, Some(&mut 400)]
955+
);
956+
}
957+
958+
#[test]
959+
#[should_panic]
960+
fn disjoint_mut_multi_fail_duplicate_panic() {
961+
let mut map: IndexMap<u32, u32> = IndexMap::default();
962+
map.insert(1, 100);
963+
map.get_disjoint_mut([&1, &2, &1]);
964+
}
965+
966+
#[test]
967+
fn disjoint_indices_mut_fail_oob() {
968+
let mut map: IndexMap<u32, u32> = IndexMap::default();
969+
map.insert(1, 10);
970+
map.insert(321, 20);
971+
assert_eq!(
972+
map.get_disjoint_indices_mut([1, 3]),
973+
Err(crate::GetDisjointMutError::IndexOutOfBounds)
974+
);
975+
}
976+
977+
#[test]
978+
fn disjoint_indices_mut_empty() {
979+
let mut map: IndexMap<u32, u32> = IndexMap::default();
980+
map.insert(1, 10);
981+
map.insert(321, 20);
982+
assert_eq!(map.get_disjoint_indices_mut([]), Ok([]));
983+
}
984+
985+
#[test]
986+
fn disjoint_indices_mut_success() {
987+
let mut map: IndexMap<u32, u32> = IndexMap::default();
988+
map.insert(1, 10);
989+
map.insert(321, 20);
990+
assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)]));
991+
992+
assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)]));
993+
assert_eq!(
994+
map.get_disjoint_indices_mut([0, 1]),
995+
Ok([(&1, &mut 10), (&321, &mut 20)])
996+
);
997+
}
998+
999+
#[test]
1000+
fn disjoint_indices_mut_fail_duplicate() {
1001+
let mut map: IndexMap<u32, u32> = IndexMap::default();
1002+
map.insert(1, 10);
1003+
map.insert(321, 20);
1004+
assert_eq!(
1005+
map.get_disjoint_indices_mut([1, 0, 1]),
1006+
Err(crate::GetDisjointMutError::OverlappingIndices)
1007+
);
1008+
}

0 commit comments

Comments
 (0)