treap_map/
map.rs

// Copyright 2019 Conflux Foundation. All rights reserved.
// Conflux is free software and distributed under GNU General Public License.
// See http://www.gnu.org/licenses/
4
5use crate::{
6    config::{ConsoliableWeight, KeyMngTrait},
7    search::{accumulate_weight_search, SearchDirection},
8    update::{ApplyOp, ApplyOpOutcome, InsertOp, RemoveOp},
9    Direction, NoWeight, SearchResult,
10};
11
12use super::{config::TreapMapConfig, node::Node};
13use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
14use rand::{RngCore, SeedableRng};
15use rand_xorshift::XorShiftRng;
16
17/// A treap map data structure.
18///
19/// See [`TreapMapConfig`][crate::TreapMapConfig] for more details.
20pub struct TreapMap<C: TreapMapConfig> {
21    /// The root node of the treap.
22    #[cfg(test)]
23    pub(crate) root: Option<Box<Node<C>>>,
24    #[cfg(not(test))]
25    root: Option<Box<Node<C>>>,
26
27    /// A map for recovering the `sort_key` from the `search_key`.
28    /// This is useful when the `sort_key` is derived from `search_key` and
29    /// `value`.
30    ext_map: C::ExtMap,
31
32    /// A random number generator used for generating priority values for new
33    /// nodes.
34    rng: XorShiftRng,
35}
36
37impl<C: TreapMapConfig> MallocSizeOf for TreapMap<C>
38where
39    Node<C>: MallocSizeOf,
40    C::ExtMap: MallocSizeOf,
41{
42    fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
43        self.root.size_of(ops) + self.ext_map.size_of(ops)
44    }
45}
46
47impl<C: TreapMapConfig> TreapMap<C> {
48    pub fn new() -> TreapMap<C> {
49        TreapMap {
50            root: None,
51            rng: XorShiftRng::from_os_rng(),
52            ext_map: Default::default(),
53        }
54    }
55
56    pub fn new_with_rng(rng: XorShiftRng) -> TreapMap<C> {
57        TreapMap {
58            root: None,
59            rng,
60            ext_map: Default::default(),
61        }
62    }
63
64    pub fn len(&self) -> usize { self.ext_map.len() }
65
66    pub fn is_empty(&self) -> bool { self.ext_map.len() == 0 }
67
68    pub fn contains_key(&self, key: &C::SearchKey) -> bool {
69        self.get(key).is_some()
70    }
71
72    pub fn insert(
73        &mut self, key: C::SearchKey, value: C::Value, weight: C::Weight,
74    ) -> Option<C::Value> {
75        let sort_key = self.ext_map.make_sort_key(&key, &value);
76
77        let node = Node::new(key, value, sort_key, weight, self.rng.next_u64());
78
79        let (result, _, _) = Node::update_inner(
80            &mut self.root,
81            InsertOp {
82                node: Box::new(node),
83                ext_map: &mut self.ext_map,
84            },
85        );
86
87        result
88    }
89
90    pub fn remove(&mut self, key: &C::SearchKey) -> Option<C::Value> {
91        let sort_key = self.ext_map.get_sort_key(&key)?;
92
93        let (result, _, _) = Node::update_inner(
94            &mut self.root,
95            RemoveOp {
96                key: (&sort_key, key),
97                ext_map: &mut self.ext_map,
98            },
99        );
100
101        result
102    }
103
104    /// Updates the value of a node with the given key in the treap map.
105    ///
106    /// # Parameters
107    /// - `key`: The search key of the node to be updated.
108    /// - `update`: A function that is called if a node with the given key
109    ///   already exists. It takes a mutable reference to the node and returns
110    ///   an `ApplyOpOutcome<T>` or a custom error `E`. See
111    ///   [`ApplyOpOutcome`][crate::ApplyOpOutcome] for more details.
112    /// - `insert`: A function that is called if a node with the given key does
113    ///   not exist. It takes a mutable reference to a random number generator
114    ///   (for computing priority for a [`Node`][crate::Node]) and should return
115    ///   a tuple containing a new `Node<C>` and a value of type `T`, or an
116    ///   error of type `E`.
117    ///   - WARNING: The key of the new node must match the key provided to the
118    ///     function.
119    pub fn update<U, I, T, E>(
120        &mut self, key: &C::SearchKey, update: U, insert: I,
121    ) -> Result<T, E>
122    where
123        U: FnOnce(&mut Node<C>) -> Result<ApplyOpOutcome<T>, E>,
124        I: FnOnce(&mut dyn RngCore) -> Result<(Node<C>, T), E>,
125    {
126        let sort_key = if let Some(sort_key) = self.ext_map.get_sort_key(key) {
127            sort_key
128        } else {
129            return match insert(&mut self.rng) {
130                Ok((node, ret)) => {
131                    self.insert(node.key, node.value, node.weight);
132                    Ok(ret)
133                }
134                Err(err) => Err(err),
135            };
136        };
137        let rng = &mut self.rng;
138        let (res, _, _) = Node::update_inner(
139            &mut self.root,
140            ApplyOp {
141                key: (&sort_key, key),
142                update,
143                insert: || insert(rng),
144                ext_map: &mut self.ext_map,
145            },
146        );
147        let (ret, maybe_node) = res?;
148        if let Some(node) = maybe_node {
149            self.insert(node.key, node.value, node.weight);
150        }
151        Ok(ret)
152    }
153
154    pub fn sum_weight(&self) -> C::Weight {
155        match &self.root {
156            Some(node) => node.sum_weight(),
157            None => C::Weight::empty(),
158        }
159    }
160
161    pub fn get(&self, key: &C::SearchKey) -> Option<&C::Value> {
162        let sort_key = self.ext_map.get_sort_key(key)?;
163        self.root.as_ref().and_then(|x| x.get(&sort_key, key))
164    }
165
166    #[inline]
167    pub fn get_by_weight(&self, weight: C::Weight) -> Option<&C::Value>
168    where C::Weight: Ord {
169        use SearchDirection::*;
170        self.search(|base, mid| {
171            if &weight < base {
172                Left
173            } else {
174                let right_base = C::Weight::consolidate(base, &mid.weight);
175                if weight < right_base {
176                    Stop
177                } else {
178                    Right(right_base)
179                }
180            }
181        })?
182        .maybe_value()
183    }
184
185    /// See details in [`crate::accumulate_weight_search`]
186    pub fn search<F>(&self, f: F) -> Option<SearchResult<'_, C, C::Weight>>
187    where F: FnMut(&C::Weight, &Node<C>) -> SearchDirection<C::Weight> {
188        Some(accumulate_weight_search(self.root.as_ref()?, f, |weight| {
189            weight
190        }))
191    }
192
193    /// See details in [`crate::accumulate_weight_search`]
194    /// If the search process does not require accessing 'weight', this function
195    /// can outperform `search` by eliminating the maintenance of the 'weight'
196    /// dimension.
197    pub fn search_no_weight<F>(
198        &self, mut f: F,
199    ) -> Option<SearchResult<'_, C, NoWeight>>
200    where F: FnMut(&Node<C>) -> SearchDirection<()> {
201        static NW: NoWeight = NoWeight;
202        Some(accumulate_weight_search(
203            self.root.as_ref()?,
204            |_, node| f(node).map_into(|_| NoWeight),
205            |_| &NW,
206        ))
207    }
208
209    pub fn iter(&self) -> Iter<'_, C> {
210        let mut iter = Iter { nodes: vec![] };
211        if let Some(ref n) = self.root {
212            iter.nodes.push(&**n);
213            iter.extend_path();
214        }
215        iter
216    }
217
218    pub fn iter_range(&self, key: &C::SearchKey) -> Iter<'_, C>
219    where C: TreapMapConfig<SortKey = ()> {
220        let mut iter = Iter { nodes: vec![] };
221        if let Some(ref n) = self.root {
222            iter.nodes.push(&**n);
223            iter.extend_path_with_key((&(), key));
224        }
225        iter
226    }
227
228    pub fn values(&self) -> impl Iterator<Item = &C::Value> {
229        self.iter().map(|node| &node.value)
230    }
231
232    pub fn key_values(
233        &self,
234    ) -> impl Iterator<Item = (&C::SearchKey, &C::Value)> {
235        self.iter().map(|node| (&node.key, &node.value))
236    }
237
238    #[cfg(any(test, feature = "testonly_code"))]
239    pub fn assert_consistency(&self)
240    where C::Weight: std::fmt::Debug {
241        if let Some(node) = self.root.as_ref() {
242            node.assert_consistency()
243        }
244    }
245}
246
247pub struct Iter<'a, C: TreapMapConfig> {
248    nodes: Vec<&'a Node<C>>,
249}
250
251impl<'a, C: TreapMapConfig> Clone for Iter<'a, C> {
252    fn clone(&self) -> Self {
253        Self {
254            nodes: self.nodes.clone(),
255        }
256    }
257}
258
259impl<'a, C: TreapMapConfig> Iter<'a, C> {
260    fn extend_path(&mut self) {
261        loop {
262            let node = *self.nodes.last().unwrap();
263            match node.left {
264                None => return,
265                Some(ref n) => self.nodes.push(&**n),
266            }
267        }
268    }
269
270    fn extend_path_with_key(&mut self, key: (&C::SortKey, &C::SearchKey)) {
271        loop {
272            let node = *self.nodes.last().unwrap();
273            match C::next_node_dir(key, (&node.sort_key, &node.key)) {
274                Some(Direction::Left) => {
275                    if let Some(left) = &node.left {
276                        self.nodes.push(left);
277                    } else {
278                        return;
279                    }
280                }
281                None => {
282                    return;
283                }
284                Some(Direction::Right) => {
285                    let node = self.nodes.pop().unwrap();
286                    if let Some(right) = &node.right {
287                        self.nodes.push(right);
288                    } else {
289                        return;
290                    }
291                }
292            }
293        }
294    }
295}
296
297impl<'a, C: TreapMapConfig> Iterator for Iter<'a, C> {
298    type Item = &'a Node<C>;
299
300    fn next(&mut self) -> Option<Self::Item> {
301        match self.nodes.pop() {
302            None => None,
303            Some(node) => {
304                if let Some(ref n) = node.right {
305                    self.nodes.push(&**n);
306                    self.extend_path();
307                }
308                Some(&node)
309            }
310        }
311    }
312}