1use std::fmt::Debug;
2
3use super::{
4 config::{ConsoliableWeight, TreapMapConfig},
5 node::Node,
6};
7
/// Instruction returned by a search callback, telling the treap search how to
/// proceed from the node it is currently visiting (as interpreted by
/// [`accumulate_weight_search`]).
///
/// The weight carried by `Right` / `RightOrStop` becomes the accumulated base
/// weight for the right subtree. It should therefore cover everything to the
/// left of that subtree, typically the left weight handed to the callback
/// consolidated with the current node's own weight.
#[derive(PartialEq, Eq, Debug)]
pub enum SearchDirection<W> {
    /// Terminate immediately. The search yields the most recently recorded
    /// candidate (from `LeftOrStop` / `RightOrStop`) if there is one, and
    /// `SearchResult::Abort` otherwise.
    Abort,

    /// The target lies in the left subtree; keep descending left.
    Left,

    /// The current node is the target; the search ends here.
    Stop,

    /// The target lies in the right subtree; keep descending right, using the
    /// carried weight as the new base weight.
    Right(W),

    /// Record the current node as a fallback result, then keep descending left
    /// looking for a better match.
    LeftOrStop,

    /// Record the current node as a fallback result, then keep descending right
    /// with the carried weight as the new base weight.
    RightOrStop(W),
}

impl<W> SearchDirection<W> {
    /// Converts the weight payload of `Right` / `RightOrStop` with `f` and
    /// leaves every payload-free variant unchanged.
    ///
    /// `f` runs at most once, and only when the direction actually carries a
    /// weight.
    #[inline]
    pub(crate) fn map_into<T, F>(self, f: F) -> SearchDirection<T>
    where
        F: FnOnce(W) -> T,
    {
        // Imported variants let each arm rebuild the value with the new
        // payload type (`Self::…` would be pinned to `SearchDirection<W>`).
        use SearchDirection::{Abort, Left, LeftOrStop, Right, RightOrStop, Stop};

        match self {
            Right(weight) => Right(f(weight)),
            RightOrStop(weight) => RightOrStop(f(weight)),
            Left => Left,
            LeftOrStop => LeftOrStop,
            Stop => Stop,
            Abort => Abort,
        }
    }
}
66
67pub enum SearchResult<'a, C: TreapMapConfig, W: ConsoliableWeight> {
73 Abort,
78
79 LeftMost,
82
83 Found { base_weight: W, node: &'a Node<C> },
88
89 RightMost(W),
93}
94
95impl<'a, C: TreapMapConfig, W: ConsoliableWeight> SearchResult<'a, C, W> {
96 pub fn maybe_value(&self) -> Option<&'a C::Value> {
97 if let SearchResult::Found { node, .. } = self {
98 Some(&node.value)
99 } else {
100 None
101 }
102 }
103}
104
105#[inline]
121pub fn accumulate_weight_search<C, W, F, E>(
122 root: &Node<C>, mut f: F, extract: E,
123) -> SearchResult<'_, C, W>
124where
125 C: TreapMapConfig,
126 F: FnMut(&W, &Node<C>) -> SearchDirection<W>,
127 W: ConsoliableWeight,
128 E: Fn(&C::Weight) -> &W,
129{
130 use SearchDirection::*;
131
132 let mut node = root;
133 let mut base_weight = W::empty();
134
135 let mut candidate_result = None;
136
137 let mut all_left = true;
138 let mut all_right = true;
139
140 loop {
142 let left_weight = if let Some(ref left) = node.left {
143 W::consolidate(&base_weight, extract(&left.sum_weight))
144 } else {
145 base_weight.clone()
146 };
147 let search_dir = f(&left_weight, &node);
148
149 let found = SearchResult::Found {
150 base_weight: left_weight,
151 node: &node,
152 };
153
154 if matches!(search_dir, Left | LeftOrStop) {
155 all_right = false;
156 }
157
158 if matches!(search_dir, Right(_) | RightOrStop(_)) {
159 all_left = false;
160 }
161
162 let next_node = match search_dir {
163 Right(_) | RightOrStop(_) => &node.right,
164 Left | LeftOrStop => &node.left,
165 Abort => {
166 return candidate_result.unwrap_or(SearchResult::Abort);
167 }
168 Stop => {
169 return found;
170 }
171 };
172
173 if matches!(search_dir, Stop | LeftOrStop | RightOrStop(_)) {
174 candidate_result = Some(found);
175 }
176
177 let right_weight = match search_dir {
178 Right(w) | RightOrStop(w) => Some(w),
179 _ => None,
180 };
181
182 if let Some(found_node) = next_node {
183 node = found_node;
184 if let Some(w) = right_weight {
185 base_weight = w;
186 }
187 } else {
188 if let Some(result) = candidate_result {
189 return result;
190 } else if all_left {
191 return SearchResult::LeftMost;
192 } else if all_right {
193 return SearchResult::RightMost(right_weight.unwrap());
194 } else {
195 return SearchResult::Abort;
196 }
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    //! Behavioural tests for the search API. Every map is built by
    //! `default_map`, which presumably inserts (key, value, weight) triples.
    //! The assertions below rely on each node having weight 3.

    use super::{
        SearchDirection::*,
        SearchResult::{self, *},
    };
    use crate::{ConsoliableWeight, SharedKeyTreapMapConfig, TreapMap};
    use std::cmp::Ordering::*;

    #[derive(Debug, PartialEq, Eq)]
    struct SearchTestConfig;
    impl SharedKeyTreapMapConfig for SearchTestConfig {
        type Key = usize;
        type Value = usize;
        type Weight = usize;
    }

    /// Builds a map with keys and values `3, 6, ..., 3 * n`, each inserted
    /// with weight 3. The weight to the left of key `k` is therefore `k - 3`.
    fn default_map(n: usize) -> TreapMap<SearchTestConfig> {
        let mut map = TreapMap::<SearchTestConfig>::new();
        for i in 1..=n {
            map.insert(i * 3, i * 3, 3);
        }
        map
    }

    /// Exact-match search without weights: hits, gaps and both out-of-range
    /// sides.
    #[test]
    fn search_no_weight() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search_no_weight(|node| match i.cmp(&node.value) {
                    Less => Left,
                    Equal => Stop,
                    Greater => Right(()),
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost)
            } else if i > 3000 {
                assert!(matches!(res, RightMost(_)));
            } else if i % 3 != 0 {
                assert_eq!(res, SearchResult::Abort);
            } else {
                assert_eq!(*res.maybe_value().unwrap(), i);
            }
        }
    }

    /// Exact-match search that also checks the accumulated left weight.
    #[test]
    fn search_with_weight() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| match i.cmp(&node.value) {
                    Less => Left,
                    Equal => Stop,
                    Greater => Right(ConsoliableWeight::consolidate(
                        left_weight,
                        &node.weight,
                    )),
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost)
            } else if i > 3000 {
                assert!(matches!(res, RightMost(_)));
            } else if i % 3 != 0 {
                assert_eq!(res, SearchResult::Abort);
            } else if let Found { base_weight, node } = res {
                assert_eq!(base_weight, i - 3);
                assert_eq!(node.key, i);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// `RightOrStop` finds the last node whose value is `<= i`.
    #[test]
    fn search_last_valid() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| {
                    if node.value <= i {
                        RightOrStop(ConsoliableWeight::consolidate(
                            left_weight,
                            &node.weight,
                        ))
                    } else {
                        Left
                    }
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost);
            } else if let Found { base_weight, node } = res {
                // Queries past the largest key clamp to it.
                let x = i.min(3000);
                assert_eq!(node.key, x - x % 3);
                assert_eq!(base_weight, node.key - 3);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// `LeftOrStop` finds the first node whose value is `> i`.
    #[test]
    fn search_first_valid() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| {
                    if node.value <= i {
                        Right(ConsoliableWeight::consolidate(
                            left_weight,
                            &node.weight,
                        ))
                    } else {
                        LeftOrStop
                    }
                })
                .unwrap();
            if i >= 3000 {
                assert_eq!(res, RightMost(3000));
            } else if let Found { base_weight, node } = res {
                assert_eq!(node.key, i - i % 3 + 3);
                assert_eq!(base_weight, node.key - 3);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// Always choosing `LeftOrStop` lands on the smallest key.
    #[test]
    fn search_left_most() {
        let map = default_map(1000);
        let res = map.search_no_weight(|_| LeftOrStop).unwrap();

        if let Found { node, .. } = res {
            assert_eq!(node.key, 3);
        } else {
            unreachable!("Unexpected");
        }
    }

    /// `iter_range(&i)` yields exactly the keys `>= i`, in order.
    ///
    /// NOTE(review): this sweep visits on the order of 10^9 elements in total
    /// (cubic in the largest `n`); consider sampling `n` if test time matters.
    #[test]
    fn iter_range() {
        for n in 1..=1000 {
            let map: TreapMap<SearchTestConfig> = default_map(n);
            for i in 0..=(3 * (n + 1)) {
                let x: Vec<usize> = map.iter_range(&i).map(|x| x.key).collect();
                let y: Vec<usize> =
                    (3usize..=(3 * n)).step_by(3).filter(|x| *x >= i).collect();
                assert_eq!(x, y);
            }
        }
    }
}
367
368mod impl_std_trait {
369 use crate::ConsoliableWeight;
370
371 use super::{Node, SearchResult, TreapMapConfig};
372 use core::{
373 cmp::PartialEq,
374 fmt::{self, Debug, Formatter},
375 };
376
377 impl<'a, C: TreapMapConfig, W: ConsoliableWeight> Debug
378 for SearchResult<'a, C, W>
379 where
380 W: Debug,
381 Node<C>: Debug,
382 {
383 #[inline]
384 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
385 match self {
386 SearchResult::Abort => Formatter::write_str(f, "Abort"),
387 SearchResult::LeftMost => Formatter::write_str(f, "LeftMost"),
388 SearchResult::Found { base_weight, node } => f
389 .debug_struct("Found")
390 .field("base_weight", base_weight)
391 .field("node", node)
392 .finish(),
393 SearchResult::RightMost(w) => {
394 f.debug_tuple("RightMost").field(w).finish()
395 }
396 }
397 }
398 }
399
400 impl<'a, C: TreapMapConfig, W: ConsoliableWeight> PartialEq
401 for SearchResult<'a, C, W>
402 where
403 C::Weight: PartialEq,
404 Node<C>: PartialEq,
405 {
406 fn eq(&self, other: &Self) -> bool {
407 match (self, other) {
408 (
409 Self::Found {
410 base_weight: l_base_weight,
411 node: l_node,
412 },
413 Self::Found {
414 base_weight: r_base_weight,
415 node: r_node,
416 },
417 ) => l_base_weight == r_base_weight && l_node == r_node,
418 (Self::RightMost(l0), Self::RightMost(r0)) => l0 == r0,
419 _ => {
420 core::mem::discriminant(self)
421 == core::mem::discriminant(other)
422 }
423 }
424 }
425 }
426}