throttling/
time_window_bucket.rs1use std::{
6 cmp::Ordering,
7 collections::{binary_heap::PeekMut, BinaryHeap, HashMap},
8 hash::Hash,
9 time::{Duration, Instant},
10};
11
12pub struct TimeWindowBucket<KEY: Eq + Hash + Clone> {
13 interval: Duration,
14 limit: usize,
15 timeouts: BinaryHeap<Item<KEY>>,
16 counters: HashMap<KEY, usize>,
17}
18
19impl<KEY: Eq + Hash + Clone> TimeWindowBucket<KEY> {
20 pub fn new(interval: Duration, limit: usize) -> Self {
21 TimeWindowBucket {
22 interval,
23 limit,
24 timeouts: BinaryHeap::new(),
25 counters: HashMap::new(),
26 }
27 }
28
29 fn refresh(&mut self) {
30 while let Some(item) = self.timeouts.peek_mut() {
31 if item.time.elapsed() <= self.interval {
32 break;
33 }
34
35 let item = PeekMut::pop(item);
36 let counter = self
37 .counters
38 .get_mut(&item.data)
39 .expect("data inconsistent");
40 if *counter <= 1 {
41 self.counters.remove(&item.data);
42 } else {
43 *counter -= 1;
44 }
45 }
46 }
47
48 pub fn try_acquire(&mut self, key: KEY) -> bool {
49 self.refresh();
50
51 let counter = self.counters.entry(key.clone()).or_default();
52 if *counter >= self.limit {
53 return false;
54 }
55
56 *counter += 1;
57 self.timeouts.push(Item::new(key));
58
59 true
60 }
61}
62
/// One recorded acquisition: the payload it was made for and when it happened.
///
/// Equality and ordering consider only `time`, and the ordering is reversed
/// (an earlier instant compares as *greater*) so that a max-heap
/// `BinaryHeap<Item<T>>` pops the oldest entry first.
struct Item<T> {
    /// Instant at which the item was created.
    time: Instant,
    /// Payload carried alongside the timestamp.
    data: T,
}

impl<T> Item<T> {
    /// Wraps `data`, stamping it with the current instant.
    fn new(data: T) -> Self {
        let time = Instant::now();
        Self { time, data }
    }
}

impl<T> PartialEq for Item<T> {
    fn eq(&self, other: &Self) -> bool {
        self.time == other.time
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for Item<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: the oldest timestamp must sit at the top of a max-heap.
        self.time.cmp(&other.time).reverse()
    }
}
92
#[cfg(test)]
mod tests {
    // `super` resolves regardless of where this module sits in the crate tree,
    // unlike a hard-coded `crate::time_window_bucket` path.
    use super::TimeWindowBucket;
    use std::{thread::sleep, time::Duration};

    #[test]
    fn test_acquire() {
        let interval = Duration::from_millis(10);
        let mut bucket = TimeWindowBucket::new(interval, 2);

        assert!(bucket.try_acquire(3));
        assert!(bucket.try_acquire(3));
        assert!(!bucket.try_acquire(3));
        // Limits are tracked per key, so key 4 is unaffected by key 3 being full.
        assert!(bucket.try_acquire(4));

        // `sleep` waits at least the requested duration, so every acquisition
        // above is now strictly older than `interval` and has expired.
        sleep(interval + Duration::from_millis(1));

        assert!(bucket.try_acquire(3));
        assert!(bucket.try_acquire(3));
        assert!(!bucket.try_acquire(3));
    }

    #[test]
    fn zero_limit_rejects_everything() {
        let mut bucket = TimeWindowBucket::new(Duration::from_secs(60), 0);

        assert!(!bucket.try_acquire("a"));
        assert!(!bucket.try_acquire("a"));
        assert!(!bucket.try_acquire("b"));
    }
}