Skip to content

Commit 853741e

Browse files
authored
enhanced segment tree implementation and more pythonic (#1715)
* enhanced segment tree implementation and more pythonic enhanced segment tree implementation and more pythonic * add doctests for segment tree * add type annotations * unified processing sum min max segment tre * delete source encoding in segment tree * use a generator function instead of returning * add doctests for methods * add doctests for methods * add doctests * fix doctest * fix doctest * fix doctest * fix function parameter and fix determine conditions
1 parent 9bb57fb commit 853741e

File tree

1 file changed

+237
-0
lines changed

1 file changed

+237
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""
2+
Segment_tree creates a segment tree with a given array and function,
3+
allowing queries to be done later in log(N) time
4+
function takes 2 values and returns a same type value
5+
"""
6+
7+
from queue import Queue
8+
from collections.abc import Sequence
9+
10+
11+
class SegmentTreeNode(object):
12+
def __init__(self, start, end, val, left=None, right=None):
13+
self.start = start
14+
self.end = end
15+
self.val = val
16+
self.mid = (start + end) // 2
17+
self.left = left
18+
self.right = right
19+
20+
def __str__(self):
21+
return 'val: %s, start: %s, end: %s' % (self.val, self.start, self.end)
22+
23+
24+
class SegmentTree(object):
25+
"""
26+
>>> import operator
27+
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
28+
>>> for node in num_arr.traverse():
29+
... print(node)
30+
...
31+
val: 15, start: 0, end: 4
32+
val: 8, start: 0, end: 2
33+
val: 7, start: 3, end: 4
34+
val: 3, start: 0, end: 1
35+
val: 5, start: 2, end: 2
36+
val: 3, start: 3, end: 3
37+
val: 4, start: 4, end: 4
38+
val: 2, start: 0, end: 0
39+
val: 1, start: 1, end: 1
40+
>>>
41+
>>> num_arr.update(1, 5)
42+
>>> for node in num_arr.traverse():
43+
... print(node)
44+
...
45+
val: 19, start: 0, end: 4
46+
val: 12, start: 0, end: 2
47+
val: 7, start: 3, end: 4
48+
val: 7, start: 0, end: 1
49+
val: 5, start: 2, end: 2
50+
val: 3, start: 3, end: 3
51+
val: 4, start: 4, end: 4
52+
val: 2, start: 0, end: 0
53+
val: 5, start: 1, end: 1
54+
>>>
55+
>>> num_arr.query_range(3, 4)
56+
7
57+
>>> num_arr.query_range(2, 2)
58+
5
59+
>>> num_arr.query_range(1, 3)
60+
13
61+
>>>
62+
>>> max_arr = SegmentTree([2, 1, 5, 3, 4], max)
63+
>>> for node in max_arr.traverse():
64+
... print(node)
65+
...
66+
val: 5, start: 0, end: 4
67+
val: 5, start: 0, end: 2
68+
val: 4, start: 3, end: 4
69+
val: 2, start: 0, end: 1
70+
val: 5, start: 2, end: 2
71+
val: 3, start: 3, end: 3
72+
val: 4, start: 4, end: 4
73+
val: 2, start: 0, end: 0
74+
val: 1, start: 1, end: 1
75+
>>>
76+
>>> max_arr.update(1, 5)
77+
>>> for node in max_arr.traverse():
78+
... print(node)
79+
...
80+
val: 5, start: 0, end: 4
81+
val: 5, start: 0, end: 2
82+
val: 4, start: 3, end: 4
83+
val: 5, start: 0, end: 1
84+
val: 5, start: 2, end: 2
85+
val: 3, start: 3, end: 3
86+
val: 4, start: 4, end: 4
87+
val: 2, start: 0, end: 0
88+
val: 5, start: 1, end: 1
89+
>>>
90+
>>> max_arr.query_range(3, 4)
91+
4
92+
>>> max_arr.query_range(2, 2)
93+
5
94+
>>> max_arr.query_range(1, 3)
95+
5
96+
>>>
97+
>>> min_arr = SegmentTree([2, 1, 5, 3, 4], min)
98+
>>> for node in min_arr.traverse():
99+
... print(node)
100+
...
101+
val: 1, start: 0, end: 4
102+
val: 1, start: 0, end: 2
103+
val: 3, start: 3, end: 4
104+
val: 1, start: 0, end: 1
105+
val: 5, start: 2, end: 2
106+
val: 3, start: 3, end: 3
107+
val: 4, start: 4, end: 4
108+
val: 2, start: 0, end: 0
109+
val: 1, start: 1, end: 1
110+
>>>
111+
>>> min_arr.update(1, 5)
112+
>>> for node in min_arr.traverse():
113+
... print(node)
114+
...
115+
val: 2, start: 0, end: 4
116+
val: 2, start: 0, end: 2
117+
val: 3, start: 3, end: 4
118+
val: 2, start: 0, end: 1
119+
val: 5, start: 2, end: 2
120+
val: 3, start: 3, end: 3
121+
val: 4, start: 4, end: 4
122+
val: 2, start: 0, end: 0
123+
val: 5, start: 1, end: 1
124+
>>>
125+
>>> min_arr.query_range(3, 4)
126+
3
127+
>>> min_arr.query_range(2, 2)
128+
5
129+
>>> min_arr.query_range(1, 3)
130+
3
131+
>>>
132+
133+
"""
134+
def __init__(self, collection: Sequence, function):
135+
self.collection = collection
136+
self.fn = function
137+
if self.collection:
138+
self.root = self._build_tree(0, len(collection) - 1)
139+
140+
def update(self, i, val):
141+
"""
142+
Update an element in log(N) time
143+
:param i: position to be update
144+
:param val: new value
145+
>>> import operator
146+
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
147+
>>> num_arr.update(1, 5)
148+
>>> num_arr.query_range(1, 3)
149+
13
150+
"""
151+
self._update_tree(self.root, i, val)
152+
153+
def query_range(self, i, j):
154+
"""
155+
Get range query value in log(N) time
156+
:param i: left element index
157+
:param j: right element index
158+
:return: element combined in the range [i, j]
159+
>>> import operator
160+
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
161+
>>> num_arr.update(1, 5)
162+
>>> num_arr.query_range(3, 4)
163+
7
164+
>>> num_arr.query_range(2, 2)
165+
5
166+
>>> num_arr.query_range(1, 3)
167+
13
168+
>>>
169+
"""
170+
return self._query_range(self.root, i, j)
171+
172+
def _build_tree(self, start, end):
173+
if start == end:
174+
return SegmentTreeNode(start, end, self.collection[start])
175+
mid = (start + end) // 2
176+
left = self._build_tree(start, mid)
177+
right = self._build_tree(mid + 1, end)
178+
return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right)
179+
180+
def _update_tree(self, node, i, val):
181+
if node.start == i and node.end == i:
182+
node.val = val
183+
return
184+
if i <= node.mid:
185+
self._update_tree(node.left, i, val)
186+
else:
187+
self._update_tree(node.right, i, val)
188+
node.val = self.fn(node.left.val, node.right.val)
189+
190+
def _query_range(self, node, i, j):
191+
if node.start == i and node.end == j:
192+
return node.val
193+
194+
if i <= node.mid:
195+
if j <= node.mid:
196+
# range in left child tree
197+
return self._query_range(node.left, i, j)
198+
else:
199+
# range in left child tree and right child tree
200+
return self.fn(self._query_range(node.left, i, node.mid), self._query_range(node.right, node.mid + 1, j))
201+
else:
202+
# range in right child tree
203+
return self._query_range(node.right, i, j)
204+
205+
def traverse(self):
206+
if self.root is not None:
207+
queue = Queue()
208+
queue.put(self.root)
209+
while not queue.empty():
210+
node = queue.get()
211+
yield node
212+
213+
if node.left is not None:
214+
queue.put(node.left)
215+
216+
if node.right is not None:
217+
queue.put(node.right)
218+
219+
220+
if __name__ == '__main__':
221+
import operator
222+
for fn in [operator.add, max, min]:
223+
print('*' * 50)
224+
arr = SegmentTree([2, 1, 5, 3, 4], fn)
225+
for node in arr.traverse():
226+
print(node)
227+
print()
228+
229+
arr.update(1, 5)
230+
for node in arr.traverse():
231+
print(node)
232+
print()
233+
234+
print(arr.query_range(3, 4)) # 7
235+
print(arr.query_range(2, 2)) # 5
236+
print(arr.query_range(1, 3)) # 13
237+
print()

0 commit comments

Comments
 (0)