treap_map/
search.rs

1use std::fmt::Debug;
2
3use super::{
4    config::{ConsoliableWeight, TreapMapConfig},
5    node::Node,
6};
7
/// Represents the directions for the search in [`accumulate_weight_search`].
///
/// The user-provided search function returns one of these values at every
/// visited node to tell `accumulate_weight_search` how the search should
/// proceed or terminate.
#[derive(Debug, PartialEq, Eq)]
pub enum SearchDirection<W> {
    /// Aborts the search immediately; no further subtree is visited.
    ///
    /// If an earlier node was accepted via
    /// [`LeftOrStop`](Self::LeftOrStop) or
    /// [`RightOrStop`](Self::RightOrStop), that node is still returned.
    Abort,

    /// Continues the search in the left subtree.
    ///
    /// Used when the current node is not an acceptable result and the
    /// search should move left.
    Left,

    /// Accepts the current node as the result and stops the search.
    Stop,

    /// Continues the search in the right subtree, carrying the weight `W`.
    ///
    /// Used when the current node is not an acceptable result and the
    /// search should move right. The user function is expected to merge the
    /// accumulated weight with the current node's weight and provide it in
    /// this variant, so that [`accumulate_weight_search`] does not have to
    /// recompute it.
    Right(W),

    /// Accepts the current node, but keeps searching in the left subtree.
    ///
    /// A node accepted inside that subtree takes precedence; if the subtree
    /// yields no result, the current node is returned.
    LeftOrStop,

    /// Accepts the current node, but keeps searching in the right subtree
    /// (the mirror image of [`LeftOrStop`](Self::LeftOrStop)).
    ///
    /// A node accepted inside that subtree takes precedence; if the subtree
    /// yields no result, the current node is returned. As with
    /// [`Right`](Self::Right), the user function is expected to merge the
    /// accumulated weight with the current node's weight and provide it in
    /// this variant, so that [`accumulate_weight_search`] does not have to
    /// recompute it.
    RightOrStop(W),
}
49
50impl<W> SearchDirection<W> {
51    #[inline]
52    pub(crate) fn map_into<T, F>(self, f: F) -> SearchDirection<T>
53    where F: FnOnce(W) -> T {
54        match self {
55            SearchDirection::Abort => SearchDirection::Abort,
56            SearchDirection::Left => SearchDirection::Left,
57            SearchDirection::Stop => SearchDirection::Stop,
58            SearchDirection::Right(v) => SearchDirection::Right(f(v)),
59            SearchDirection::LeftOrStop => SearchDirection::LeftOrStop,
60            SearchDirection::RightOrStop(v) => {
61                SearchDirection::RightOrStop(f(v))
62            }
63        }
64    }
65}
66
67/// Represents the possible outcomes of the `accumulate_weight_search`.
68///
69/// This enum encapsulates the results that can be returned by
70/// `accumulate_weight_search`, indicating the outcome of the search within a
71/// treap map.
72pub enum SearchResult<'a, C: TreapMapConfig, W: ConsoliableWeight> {
73    /// Indicates that the search was aborted.
74    /// This variant is used when no feasible result is found and the search
75    /// position is neither at the extreme left nor the extreme right of
76    /// the treap.
77    Abort,
78
79    /// Indicates that the search reached the leftmost edge of the entire treap
80    /// without finding a feasible result.
81    LeftMost,
82
83    /// Represents a successful search, indicating a feasible result has been
84    /// found. Contains `base_weight`, which is the total weight from the
85    /// leftmost edge up to but not including the current node,
86    /// and a reference to the `node` itself.
87    Found { base_weight: W, node: &'a Node<C> },
88
89    /// Indicates that the search reached the rightmost edge of the entire
90    /// treap without finding a feasible result. Also returns the total
91    /// weight of the entire tree (`RightMost(W)`).
92    RightMost(W),
93}
94
95impl<'a, C: TreapMapConfig, W: ConsoliableWeight> SearchResult<'a, C, W> {
96    pub fn maybe_value(&self) -> Option<&'a C::Value> {
97        if let SearchResult::Found { node, .. } = self {
98            Some(&node.value)
99        } else {
100            None
101        }
102    }
103}
104
105/// Performs a binary search in a treap-map.
106///
107/// This function conducts a binary search within a treap-map structure, where
108/// at each step it can access the accumulated weight from the leftmost node to
109/// the current node.
110///
111/// # Parameters
112/// - `node`: The root node of the treap-map.
113/// - `f`: A search function that takes the accumulated weight from the leftmost
114///   node to the current node (excluding the current node) and the current node
115///   itself. It returns a search direction (see [`SearchDirection`
116///   struct][SearchDirection] for more details).
117/// - `extract`: A function to extract a subset of the weight stored in the
118///   treap-map. This allows for avoiding the reading and maintenance of fields
119///   that are not needed during the search.
120#[inline]
121pub fn accumulate_weight_search<C, W, F, E>(
122    root: &Node<C>, mut f: F, extract: E,
123) -> SearchResult<'_, C, W>
124where
125    C: TreapMapConfig,
126    F: FnMut(&W, &Node<C>) -> SearchDirection<W>,
127    W: ConsoliableWeight,
128    E: Fn(&C::Weight) -> &W,
129{
130    use SearchDirection::*;
131
132    let mut node = root;
133    let mut base_weight = W::empty();
134
135    let mut candidate_result = None;
136
137    let mut all_left = true;
138    let mut all_right = true;
139
140    // Using loops instead of recursion can improve performance by 20%.
141    loop {
142        let left_weight = if let Some(ref left) = node.left {
143            W::consolidate(&base_weight, extract(&left.sum_weight))
144        } else {
145            base_weight.clone()
146        };
147        let search_dir = f(&left_weight, &node);
148
149        let found = SearchResult::Found {
150            base_weight: left_weight,
151            node: &node,
152        };
153
154        if matches!(search_dir, Left | LeftOrStop) {
155            all_right = false;
156        }
157
158        if matches!(search_dir, Right(_) | RightOrStop(_)) {
159            all_left = false;
160        }
161
162        let next_node = match search_dir {
163            Right(_) | RightOrStop(_) => &node.right,
164            Left | LeftOrStop => &node.left,
165            Abort => {
166                return candidate_result.unwrap_or(SearchResult::Abort);
167            }
168            Stop => {
169                return found;
170            }
171        };
172
173        if matches!(search_dir, Stop | LeftOrStop | RightOrStop(_)) {
174            candidate_result = Some(found);
175        }
176
177        let right_weight = match search_dir {
178            Right(w) | RightOrStop(w) => Some(w),
179            _ => None,
180        };
181
182        if let Some(found_node) = next_node {
183            node = found_node;
184            if let Some(w) = right_weight {
185                base_weight = w;
186            }
187        } else {
188            if let Some(result) = candidate_result {
189                return result;
190            } else if all_left {
191                return SearchResult::LeftMost;
192            } else if all_right {
193                return SearchResult::RightMost(right_weight.unwrap());
194            } else {
195                return SearchResult::Abort;
196            }
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::{
        SearchDirection::*,
        SearchResult::{self, *},
    };
    use crate::{ConsoliableWeight, SharedKeyTreapMapConfig, TreapMap};
    use std::cmp::Ordering::*;

    #[derive(Debug, PartialEq, Eq)]
    struct SearchTestConfig;
    impl SharedKeyTreapMapConfig for SearchTestConfig {
        type Key = usize;
        type Value = usize;
        type Weight = usize;
    }

    /// Builds a map whose keys and values are `3, 6, ..., 3 * n`, each
    /// inserted with weight 3.
    fn default_map(n: usize) -> TreapMap<SearchTestConfig> {
        let mut map = TreapMap::<SearchTestConfig>::new();
        for i in 1..=n {
            map.insert(i * 3, i * 3, 3);
        }
        map
    }

    /// Exact-match search without weights: hits, misses between keys, and
    /// both out-of-range edges.
    #[test]
    fn search_no_weight() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search_no_weight(|node| match i.cmp(&node.value) {
                    Less => Left,
                    Equal => Stop,
                    Greater => Right(()),
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost);
            } else if i > 3000 {
                assert!(matches!(res, RightMost(_)));
            } else if i % 3 != 0 {
                assert_eq!(res, SearchResult::Abort);
            } else {
                assert_eq!(*res.maybe_value().unwrap(), i);
            }
        }
    }

    /// Exact-match search with weights: the base weight of a hit is the
    /// weight of all smaller keys, and running off the right edge reports
    /// the total weight of the map.
    #[test]
    fn search_with_weight() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| match i.cmp(&node.value) {
                    Less => Left,
                    Equal => Stop,
                    Greater => Right(ConsoliableWeight::consolidate(
                        left_weight,
                        &node.weight,
                    )),
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost);
            } else if i > 3000 {
                assert_eq!(res, RightMost(3000));
            } else if i % 3 != 0 {
                assert_eq!(res, SearchResult::Abort);
            } else if let Found { base_weight, node } = res {
                assert_eq!(base_weight, i - 3);
                assert_eq!(node.key, i);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// `RightOrStop` finds the last key not exceeding `i`.
    #[test]
    fn search_last_valid() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| {
                    if node.value <= i {
                        RightOrStop(ConsoliableWeight::consolidate(
                            left_weight,
                            &node.weight,
                        ))
                    } else {
                        Left
                    }
                })
                .unwrap();
            if i < 3 {
                assert_eq!(res, LeftMost);
            } else if let Found { base_weight, node } = res {
                // Keys stop at 3000, so larger targets clamp to it.
                let x = i.min(3000);
                assert_eq!(node.key, x - x % 3);
                assert_eq!(base_weight, node.key - 3);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// `LeftOrStop` finds the first key strictly greater than `i`.
    #[test]
    fn search_first_valid() {
        let map = default_map(1000);
        for i in 0usize..=3003 {
            let res = map
                .search(|left_weight, node| {
                    if node.value <= i {
                        Right(ConsoliableWeight::consolidate(
                            left_weight,
                            &node.weight,
                        ))
                    } else {
                        LeftOrStop
                    }
                })
                .unwrap();
            if i >= 3000 {
                assert_eq!(res, RightMost(3000));
            } else if let Found { base_weight, node } = res {
                assert_eq!(node.key, i - i % 3 + 3);
                assert_eq!(base_weight, node.key - 3);
            } else {
                unreachable!("Unexpected");
            }
        }
    }

    /// Always preferring the left subtree lands on the smallest key.
    #[test]
    fn search_left_most() {
        let map = default_map(1000);
        let res = map.search_no_weight(|_| LeftOrStop).unwrap();

        if let Found { node, .. } = res {
            assert_eq!(node.key, 3);
        } else {
            unreachable!("Unexpected");
        }
    }

    /// Exhaustively checks `iter_range` for every map size up to 1000 and
    /// every start key, against the expected sorted key list.
    #[test]
    fn iter_range() {
        for n in 1..=1000 {
            let map: TreapMap<SearchTestConfig> = default_map(n);
            for i in 0..=(3 * (n + 1)) {
                let x: Vec<usize> = map.iter_range(&i).map(|x| x.key).collect();
                let y: Vec<usize> =
                    (3usize..=(3 * n)).step_by(3).filter(|x| *x >= i).collect();
                assert_eq!(x, y);
            }
        }
    }
}
367
368mod impl_std_trait {
369    use crate::ConsoliableWeight;
370
371    use super::{Node, SearchResult, TreapMapConfig};
372    use core::{
373        cmp::PartialEq,
374        fmt::{self, Debug, Formatter},
375    };
376
377    impl<'a, C: TreapMapConfig, W: ConsoliableWeight> Debug
378        for SearchResult<'a, C, W>
379    where
380        W: Debug,
381        Node<C>: Debug,
382    {
383        #[inline]
384        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
385            match self {
386                SearchResult::Abort => Formatter::write_str(f, "Abort"),
387                SearchResult::LeftMost => Formatter::write_str(f, "LeftMost"),
388                SearchResult::Found { base_weight, node } => f
389                    .debug_struct("Found")
390                    .field("base_weight", base_weight)
391                    .field("node", node)
392                    .finish(),
393                SearchResult::RightMost(w) => {
394                    f.debug_tuple("RightMost").field(w).finish()
395                }
396            }
397        }
398    }
399
400    impl<'a, C: TreapMapConfig, W: ConsoliableWeight> PartialEq
401        for SearchResult<'a, C, W>
402    where
403        C::Weight: PartialEq,
404        Node<C>: PartialEq,
405    {
406        fn eq(&self, other: &Self) -> bool {
407            match (self, other) {
408                (
409                    Self::Found {
410                        base_weight: l_base_weight,
411                        node: l_node,
412                    },
413                    Self::Found {
414                        base_weight: r_base_weight,
415                        node: r_node,
416                    },
417                ) => l_base_weight == r_base_weight && l_node == r_node,
418                (Self::RightMost(l0), Self::RightMost(r0)) => l0 == r0,
419                _ => {
420                    core::mem::discriminant(self)
421                        == core::mem::discriminant(other)
422                }
423            }
424        }
425    }
426}