1#[cfg(test)]
2mod tests;
3
4use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
5use std::{cmp::Ordering, collections::HashMap, fmt::Debug, hash};
6
7#[derive(DeriveMallocSizeOf)]
10pub struct HeapMap<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> {
11 data: Vec<Node<K, V>>,
12 mapping: HashMap<K, usize>,
13}
14
15#[derive(Clone, DeriveMallocSizeOf)]
16pub struct Node<K, V: Eq + Ord> {
17 key: K,
18 value: V,
19}
20
21impl<K, V: Eq + Ord> Node<K, V> {
22 pub fn new(key: K, value: V) -> Self { Node { key, value } }
23}
24
25impl<K, V: Eq + Ord> PartialEq for Node<K, V> {
26 fn eq(&self, other: &Self) -> bool { self.value.eq(&other.value) }
27}
28
29impl<K, V: Eq + Ord> Eq for Node<K, V> {}
30
31impl<K, V: Eq + Ord> PartialOrd for Node<K, V> {
32 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
33 Some(self.cmp(other))
34 }
35}
36
37impl<K, V: Eq + Ord> Ord for Node<K, V> {
38 fn cmp(&self, other: &Self) -> Ordering { self.value.cmp(&other.value) }
39}
40
41impl<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> Default
42 for HeapMap<K, V>
43{
44 fn default() -> Self { Self::new() }
45}
46
47impl<K: hash::Hash + Eq + Copy + Debug, V: Eq + Ord + Clone> HeapMap<K, V> {
48 pub fn new() -> Self {
49 Self {
50 data: vec![],
51 mapping: HashMap::new(),
52 }
53 }
54
55 pub fn insert(&mut self, key: &K, value: V) -> Option<V> {
58 if self.mapping.contains_key(key) {
59 let old_value = self.update(key, value);
60 Some(old_value)
61 } else {
62 self.append(key, value);
63 None
64 }
65 }
66
67 pub fn remove(&mut self, key: &K) -> Option<V> {
69 let index = self.mapping.remove(key)?;
70 let removed_node = self.data.swap_remove(index);
71 if index != self.data.len() {
72 match self.data[index].cmp(&removed_node) {
74 Ordering::Less => self.sift_down(index),
75 Ordering::Greater => self.sift_up(index),
76 Ordering::Equal => {
77 self.mapping.insert(self.data[index].key, index);
78 }
79 }
80 }
81 Some(removed_node.value)
82 }
83
84 pub fn update_with<F>(&mut self, key: &K, mut update_fn: F)
86 where F: FnMut(&mut V) -> () {
87 let index = match self.mapping.get(&key) {
88 None => {
89 return;
90 }
91 Some(i) => *i,
92 };
93 let origin_node = self.data[index].clone();
94 update_fn(&mut self.data[index].value);
95 match self.data[index].cmp(&origin_node) {
97 Ordering::Less => self.sift_down(index),
98 Ordering::Greater => self.sift_up(index),
99 _ => {}
100 }
101 }
102
103 pub fn top(&self) -> Option<(&K, &V)> {
105 self.data.get(0).map(|node| (&node.key, &node.value))
106 }
107
108 pub fn pop(&mut self) -> Option<(K, V)> {
110 if self.is_empty() {
111 return None;
112 }
113 let item = self.data.swap_remove(0);
114 if !self.is_empty() {
115 self.sift_down(0);
116 }
117 self.mapping.remove(&item.key);
118 Some((item.key, item.value))
119 }
120
121 pub fn get(&self, key: &K) -> Option<&V> {
123 let index = *self.mapping.get(key)?;
124 self.data.get(index).map(|node| &node.value)
125 }
126
127 pub fn clear(&mut self) {
129 self.mapping.clear();
130 self.data.clear();
131 }
132
133 #[inline]
134 pub fn is_empty(&self) -> bool { self.data.is_empty() }
135
136 #[inline]
137 pub fn len(&self) -> usize { self.data.len() }
138
139 pub fn iter(&self) -> impl Iterator<Item = V> + '_ {
140 self.data.iter().map(|f| f.value.clone())
141 }
142
143 fn update(&mut self, key: &K, value: V) -> V {
144 let index = *self.mapping.get(key).unwrap();
145 let origin_node = self.data[index].clone();
146 self.data[index] = Node::new(*key, value);
147 match self.data[index].cmp(&origin_node) {
148 Ordering::Less => self.sift_down(index),
149 Ordering::Greater => self.sift_up(index),
150 _ => {}
151 }
152 origin_node.value
153 }
154
155 fn append(&mut self, key: &K, value: V) {
156 self.data.push(Node::new(*key, value));
157 self.sift_up(self.data.len() - 1);
158 }
159
160 fn sift_up(&mut self, index: usize) {
161 let val = self.data[index].clone();
162 let mut pos = index;
163 while pos > 0 {
164 let parent = (pos - 1) / 2;
165 if self.data[parent] >= val {
166 break;
167 }
168 self.data[pos] = self.data[parent].clone();
169 self.mapping.insert(self.data[pos].key, pos);
170 pos = parent;
171 }
172
173 self.mapping.insert(val.key, pos);
174 self.data[pos] = val;
175 }
176
177 fn sift_down(&mut self, index: usize) {
178 let val = self.data[index].clone();
179 let mut pos = index;
180 let mut child = pos * 2 + 1;
181 while child < self.data.len() {
182 let right = child + 1;
183 if right < self.data.len() && self.data[right] > self.data[child] {
184 child = right;
185 }
186 if val >= self.data[child] {
187 break;
188 }
189 self.data[pos] = self.data[child].clone();
190 self.mapping.insert(self.data[pos].key, pos);
191 pos = child;
192 child = pos * 2 + 1;
193 }
194 self.mapping.insert(val.key, pos);
195 self.data[pos] = val;
196 }
197}