heap_map/
lib.rs

1#[cfg(test)]
2mod tests;
3
4use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
5use std::{cmp::Ordering, collections::HashMap, fmt::Debug, hash};
6
7/// The `HeapMap` maintain a max heap along with a hash map to support
8/// additional `remove` and `update` operations.
9#[derive(DeriveMallocSizeOf)]
10pub struct HeapMap<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> {
11    data: Vec<Node<K, V>>,
12    mapping: HashMap<K, usize>,
13}
14
15#[derive(Clone, DeriveMallocSizeOf)]
16pub struct Node<K, V: Eq + Ord> {
17    key: K,
18    value: V,
19}
20
21impl<K, V: Eq + Ord> Node<K, V> {
22    pub fn new(key: K, value: V) -> Self { Node { key, value } }
23}
24
25impl<K, V: Eq + Ord> PartialEq for Node<K, V> {
26    fn eq(&self, other: &Self) -> bool { self.value.eq(&other.value) }
27}
28
29impl<K, V: Eq + Ord> Eq for Node<K, V> {}
30
31impl<K, V: Eq + Ord> PartialOrd for Node<K, V> {
32    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
33        Some(self.cmp(other))
34    }
35}
36
37impl<K, V: Eq + Ord> Ord for Node<K, V> {
38    fn cmp(&self, other: &Self) -> Ordering { self.value.cmp(&other.value) }
39}
40
41impl<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> Default
42    for HeapMap<K, V>
43{
44    fn default() -> Self { Self::new() }
45}
46
47impl<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> HeapMap<K, V> {
48    pub fn new() -> Self {
49        Self {
50            data: vec![],
51            mapping: HashMap::new(),
52        }
53    }
54
55    /// Insert a K-V into the HeapMap.
56    /// Return the old value if `key` already exist. Return `None` otherwise.
57    pub fn insert(&mut self, key: &K, value: V) -> Option<V> {
58        if self.mapping.contains_key(key) {
59            let old_value = self.update(key, value);
60            Some(old_value)
61        } else {
62            self.append(key, value);
63            None
64        }
65    }
66
67    /// Remove `key` from the HeapMap.
68    pub fn remove(&mut self, key: &K) -> Option<V> {
69        let index = self.mapping.remove(key)?;
70        let removed_node = self.data.swap_remove(index);
71        if index != self.data.len() {
72            // The last node has been swapped to index
73            match self.data[index].cmp(&removed_node) {
74                Ordering::Less => self.sift_down(index),
75                Ordering::Greater => self.sift_up(index),
76                Ordering::Equal => {
77                    self.mapping.insert(self.data[index].key, index);
78                }
79            }
80        }
81        Some(removed_node.value)
82    }
83
84    /// In-place update some fields of a node's value.
85    pub fn update_with<F>(&mut self, key: &K, mut update_fn: F)
86    where F: FnMut(&mut V) -> () {
87        let index = match self.mapping.get(&key) {
88            None => {
89                return;
90            }
91            Some(i) => *i,
92        };
93        let origin_node = self.data[index].clone();
94        update_fn(&mut self.data[index].value);
95        // The order of node is the opposite of the order of this tuple.
96        match self.data[index].cmp(&origin_node) {
97            Ordering::Less => self.sift_down(index),
98            Ordering::Greater => self.sift_up(index),
99            _ => {}
100        }
101    }
102
103    /// Return the top K-V reference tuple.
104    pub fn top(&self) -> Option<(&K, &V)> {
105        self.data.get(0).map(|node| (&node.key, &node.value))
106    }
107
108    /// Pop the top node and return it as a K-V tuple.
109    pub fn pop(&mut self) -> Option<(K, V)> {
110        if self.is_empty() {
111            return None;
112        }
113        let item = self.data.swap_remove(0);
114        if !self.is_empty() {
115            self.sift_down(0);
116        }
117        self.mapping.remove(&item.key);
118        Some((item.key, item.value))
119    }
120
121    /// Get the value reference of `key`.
122    pub fn get(&self, key: &K) -> Option<&V> {
123        let index = *self.mapping.get(key)?;
124        self.data.get(index).map(|node| &node.value)
125    }
126
127    /// Clear all key-values of the HeapMap.
128    pub fn clear(&mut self) {
129        self.mapping.clear();
130        self.data.clear();
131    }
132
133    #[inline]
134    pub fn is_empty(&self) -> bool { self.data.is_empty() }
135
136    #[inline]
137    pub fn len(&self) -> usize { self.data.len() }
138
139    pub fn iter(&self) -> impl Iterator<Item = V> + '_ {
140        self.data.iter().map(|f| f.value.clone())
141    }
142
143    fn update(&mut self, key: &K, value: V) -> V {
144        let index = *self.mapping.get(key).unwrap();
145        let origin_node = self.data[index].clone();
146        self.data[index] = Node::new(*key, value);
147        match self.data[index].cmp(&origin_node) {
148            Ordering::Less => self.sift_down(index),
149            Ordering::Greater => self.sift_up(index),
150            _ => {}
151        }
152        origin_node.value
153    }
154
155    fn append(&mut self, key: &K, value: V) {
156        self.data.push(Node::new(*key, value));
157        self.sift_up(self.data.len() - 1);
158    }
159
160    fn sift_up(&mut self, index: usize) {
161        let val = self.data[index].clone();
162        let mut pos = index;
163        while pos > 0 {
164            let parent = (pos - 1) / 2;
165            if self.data[parent] >= val {
166                break;
167            }
168            self.data[pos] = self.data[parent].clone();
169            self.mapping.insert(self.data[pos].key, pos);
170            pos = parent;
171        }
172
173        self.mapping.insert(val.key, pos);
174        self.data[pos] = val;
175    }
176
177    fn sift_down(&mut self, index: usize) {
178        let val = self.data[index].clone();
179        let mut pos = index;
180        let mut child = pos * 2 + 1;
181        while child < self.data.len() {
182            let right = child + 1;
183            if right < self.data.len() && self.data[right] > self.data[child] {
184                child = right;
185            }
186            if val >= self.data[child] {
187                break;
188            }
189            self.data[pos] = self.data[child].clone();
190            self.mapping.insert(self.data[pos].key, pos);
191            pos = child;
192            child = pos * 2 + 1;
193        }
194        self.mapping.insert(val.key, pos);
195        self.data[pos] = val;
196    }
197}