Skip to content

Commit ca813d5

Browse files
authored
Add Recursive Segment Tree (#499)
1 parent a01d85a commit ca813d5

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

src/data_structures/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod probabilistic;
99
mod queue;
1010
mod rb_tree;
1111
mod segment_tree;
12+
mod segment_tree_recursive;
1213
mod stack_using_singly_linked_list;
1314
mod treap;
1415
mod trie;
@@ -25,6 +26,7 @@ pub use self::linked_list::LinkedList;
2526
pub use self::queue::Queue;
2627
pub use self::rb_tree::RBTree;
2728
pub use self::segment_tree::SegmentTree;
29+
pub use self::segment_tree_recursive::SegmentTree as SegmentTreeRecursive;
2830
pub use self::stack_using_singly_linked_list::Stack;
2931
pub use self::treap::Treap;
3032
pub use self::trie::Trie;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
use std::fmt::Debug;
2+
use std::ops::Range;
3+
4+
pub struct SegmentTree<T: Debug + Default + Ord + Copy> {
5+
len: usize, // length of the represented
6+
tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance)
7+
merge: fn(T, T) -> T, // how we merge two values together
8+
}
9+
10+
impl<T: Debug + Default + Ord + Copy> SegmentTree<T> {
11+
pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self {
12+
let len = arr.len();
13+
let mut sgtr = SegmentTree {
14+
len,
15+
tree: vec![T::default(); 4 * len],
16+
merge,
17+
};
18+
if len != 0 {
19+
sgtr.build_recursive(arr, 1, 0..len, merge);
20+
}
21+
sgtr
22+
}
23+
24+
fn build_recursive(
25+
&mut self,
26+
arr: &[T],
27+
idx: usize,
28+
range: Range<usize>,
29+
merge: fn(T, T) -> T,
30+
) {
31+
if range.end - range.start == 1 {
32+
self.tree[idx] = arr[range.start];
33+
} else {
34+
let mid = (range.start + range.end) / 2;
35+
self.build_recursive(arr, 2 * idx, range.start..mid, merge);
36+
self.build_recursive(arr, 2 * idx + 1, mid..range.end, merge);
37+
self.tree[idx] = merge(self.tree[2 * idx], self.tree[2 * idx + 1]);
38+
}
39+
}
40+
41+
/// Query the range (exclusive)
42+
/// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.)
43+
/// return the aggregate of values over this range otherwise
44+
pub fn query(&self, range: Range<usize>) -> Option<T> {
45+
self.query_recursive(1, 0..self.len, &range)
46+
}
47+
48+
fn query_recursive(
49+
&self,
50+
idx: usize,
51+
element_range: Range<usize>,
52+
query_range: &Range<usize>,
53+
) -> Option<T> {
54+
if element_range.start >= query_range.end || element_range.end <= query_range.start {
55+
return None;
56+
}
57+
if element_range.start >= query_range.start && element_range.end <= query_range.end {
58+
return Some(self.tree[idx]);
59+
}
60+
let mid = (element_range.start + element_range.end) / 2;
61+
let left = self.query_recursive(idx * 2, element_range.start..mid, query_range);
62+
let right = self.query_recursive(idx * 2 + 1, mid..element_range.end, query_range);
63+
match (left, right) {
64+
(None, None) => None,
65+
(None, Some(r)) => Some(r),
66+
(Some(l), None) => Some(l),
67+
(Some(l), Some(r)) => Some((self.merge)(l, r)),
68+
}
69+
}
70+
71+
/// Updates the value at index `idx` in the original array with a new value `val`
72+
pub fn update(&mut self, idx: usize, val: T) {
73+
self.update_recursive(1, 0..self.len, idx, val);
74+
}
75+
76+
fn update_recursive(
77+
&mut self,
78+
idx: usize,
79+
element_range: Range<usize>,
80+
target_idx: usize,
81+
val: T,
82+
) {
83+
println!("{:?}", element_range);
84+
if element_range.start > target_idx || element_range.end <= target_idx {
85+
return;
86+
}
87+
if element_range.end - element_range.start <= 1 && element_range.start == target_idx {
88+
println!("{:?}", element_range);
89+
self.tree[idx] = val;
90+
return;
91+
}
92+
let mid = (element_range.start + element_range.end) / 2;
93+
self.update_recursive(idx * 2, element_range.start..mid, target_idx, val);
94+
self.update_recursive(idx * 2 + 1, mid..element_range.end, target_idx, val);
95+
self.tree[idx] = (self.merge)(self.tree[idx * 2], self.tree[idx * 2 + 1]);
96+
}
97+
}
98+
99+
#[cfg(test)]
100+
mod tests {
101+
use super::*;
102+
use quickcheck::TestResult;
103+
use quickcheck_macros::quickcheck;
104+
use std::cmp::{max, min};
105+
106+
#[test]
107+
fn test_min_segments() {
108+
let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
109+
let min_seg_tree = SegmentTree::from_vec(&vec, min);
110+
assert_eq!(Some(-5), min_seg_tree.query(4..7));
111+
assert_eq!(Some(-30), min_seg_tree.query(0..vec.len()));
112+
assert_eq!(Some(-30), min_seg_tree.query(0..2));
113+
assert_eq!(Some(-4), min_seg_tree.query(1..3));
114+
assert_eq!(Some(-5), min_seg_tree.query(1..7));
115+
}
116+
117+
#[test]
118+
fn test_max_segments() {
119+
let val_at_6 = 6;
120+
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
121+
let mut max_seg_tree = SegmentTree::from_vec(&vec, max);
122+
assert_eq!(Some(15), max_seg_tree.query(0..vec.len()));
123+
let max_4_to_6 = 6;
124+
assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7));
125+
let delta = 2;
126+
max_seg_tree.update(6, val_at_6 + delta);
127+
assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7));
128+
}
129+
130+
#[test]
131+
fn test_sum_segments() {
132+
let val_at_6 = 6;
133+
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
134+
let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b);
135+
for (i, val) in vec.iter().enumerate() {
136+
assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1)));
137+
}
138+
let sum_4_to_6 = sum_seg_tree.query(4..7);
139+
assert_eq!(Some(4), sum_4_to_6);
140+
let delta = 3;
141+
sum_seg_tree.update(6, val_at_6 + delta);
142+
assert_eq!(
143+
sum_4_to_6.unwrap() + delta,
144+
sum_seg_tree.query(4..7).unwrap()
145+
);
146+
}
147+
148+
// Some properties over segment trees:
149+
// When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc.
150+
// When asking for an interval containing a single value, return this value, no matter the merge function
151+
152+
#[quickcheck]
153+
fn check_overall_interval_min(array: Vec<i32>) -> TestResult {
154+
let seg_tree = SegmentTree::from_vec(&array, min);
155+
TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len()))
156+
}
157+
158+
#[quickcheck]
159+
fn check_overall_interval_max(array: Vec<i32>) -> TestResult {
160+
let seg_tree = SegmentTree::from_vec(&array, max);
161+
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
162+
}
163+
164+
#[quickcheck]
165+
fn check_overall_interval_sum(array: Vec<i32>) -> TestResult {
166+
let seg_tree = SegmentTree::from_vec(&array, max);
167+
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
168+
}
169+
170+
#[quickcheck]
171+
fn check_single_interval_min(array: Vec<i32>) -> TestResult {
172+
let seg_tree = SegmentTree::from_vec(&array, min);
173+
for (i, value) in array.into_iter().enumerate() {
174+
let res = seg_tree.query(i..(i + 1));
175+
if res != Some(value) {
176+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
177+
}
178+
}
179+
TestResult::passed()
180+
}
181+
182+
#[quickcheck]
183+
fn check_single_interval_max(array: Vec<i32>) -> TestResult {
184+
let seg_tree = SegmentTree::from_vec(&array, max);
185+
for (i, value) in array.into_iter().enumerate() {
186+
let res = seg_tree.query(i..(i + 1));
187+
if res != Some(value) {
188+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
189+
}
190+
}
191+
TestResult::passed()
192+
}
193+
194+
#[quickcheck]
195+
fn check_single_interval_sum(array: Vec<i32>) -> TestResult {
196+
let seg_tree = SegmentTree::from_vec(&array, max);
197+
for (i, value) in array.into_iter().enumerate() {
198+
let res = seg_tree.query(i..(i + 1));
199+
if res != Some(value) {
200+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
201+
}
202+
}
203+
TestResult::passed()
204+
}
205+
}

0 commit comments

Comments
 (0)