1use malloc_size_of::MallocSizeOf;
6use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
7use parking_lot::Mutex;
8use std::{
9 cmp::{max, min},
10 collections::HashMap,
11 fs::read_to_string,
12 hash::Hash,
13 str::FromStr,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17
/// Outcome of `TokenBucket::throttle`.
///
/// This is a small value type (its only payload is a `Duration`), so it is
/// `Copy` and can be passed around and compared freely.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ThrottleResult {
    /// The requested tokens were available and have been deducted.
    Success,
    /// The request was rejected; the payload is the remaining wait before
    /// the bucket expects to be able to serve it (or before the current
    /// throttle window ends).
    Throttled(Duration),
    /// The request was rejected inside an active throttle window after the
    /// tolerated number of `Throttled` replies for that window was used up.
    AlreadyThrottled,
}
24
25#[derive(DeriveMallocSizeOf)]
26pub struct ThrottleTokens {
27 max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64, }
32
33impl ThrottleTokens {
34 pub fn new(
35 max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64,
36 ) -> Self {
37 ThrottleTokens {
38 max_tokens,
39 cur_tokens,
40 recharge_rate,
41 default_cost,
42 }
43 }
44}
45
46#[derive(DeriveMallocSizeOf)]
47pub struct TokenBucket {
48 cpu_tokens: ThrottleTokens,
49 message_size_tokens: ThrottleTokens,
50 last_update: Instant,
51 throttled_until: Option<Instant>,
53 throttled_counter: u64,
57 max_throttled_counter: u64,
58}
59
60impl TokenBucket {
61 pub fn new(
62 max_cpu_tokens: u64, cur_cpu_tokens: u64, cpu_token_recharge_rate: u64,
63 default_cpu_cost: u64, max_message_tokens: u64,
64 cur_message_tokens: u64, message_token_recharge_rate: u64,
65 default_message_cost: u64,
66 ) -> Self {
67 assert!(cur_cpu_tokens <= max_cpu_tokens);
68 assert!(cur_message_tokens <= max_message_tokens);
69
70 TokenBucket {
71 cpu_tokens: ThrottleTokens::new(
72 max_cpu_tokens,
73 cur_cpu_tokens,
74 cpu_token_recharge_rate,
75 default_cpu_cost,
76 ),
77 message_size_tokens: ThrottleTokens::new(
78 max_message_tokens,
79 cur_message_tokens,
80 message_token_recharge_rate,
81 default_message_cost,
82 ),
83 last_update: Instant::now(),
84 throttled_until: None,
85 throttled_counter: 0,
86 max_throttled_counter: 0,
87 }
88 }
89
90 pub fn full(
91 max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
92 default_cpu_cost: u64, max_message_tokens: u64,
93 message_token_recharge_rate: u64, default_message_cost: u64,
94 ) -> Self {
95 Self::new(
96 max_cpu_tokens,
97 max_cpu_tokens,
98 cpu_token_recharge_rate,
99 default_cpu_cost,
100 max_message_tokens,
101 max_message_tokens,
102 message_token_recharge_rate,
103 default_message_cost,
104 )
105 }
106
107 pub fn empty(
108 max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
109 default_cpu_cost: u64, max_message_tokens: u64,
110 message_token_recharge_rate: u64, default_message_cost: u64,
111 ) -> Self {
112 Self::new(
113 max_cpu_tokens,
114 0, cpu_token_recharge_rate,
116 default_cpu_cost,
117 max_message_tokens,
118 0, message_token_recharge_rate,
120 default_message_cost,
121 )
122 }
123
124 pub fn set_max_throttled_counter(&mut self, max_throttled_counter: u64) {
125 self.max_throttled_counter = max_throttled_counter;
126 }
127
128 fn refresh(&mut self, now: Instant) {
129 let elapsed_secs = (now - self.last_update).as_secs();
130 if elapsed_secs == 0 {
131 return;
132 }
133
134 let cpu_recharged = self.cpu_tokens.recharge_rate * elapsed_secs;
135 self.cpu_tokens.cur_tokens = min(
136 self.cpu_tokens.max_tokens,
137 self.cpu_tokens.cur_tokens + cpu_recharged,
138 );
139 let message_recharged =
140 self.message_size_tokens.recharge_rate * elapsed_secs;
141 self.message_size_tokens.cur_tokens = min(
142 self.message_size_tokens.max_tokens,
143 self.message_size_tokens.cur_tokens + message_recharged,
144 );
145 self.last_update += Duration::from_secs(elapsed_secs);
146 }
147
148 fn try_acquire_cost(
149 &mut self, cpu_cost: u64, message_size_cost: u64,
150 ) -> Result<(), Duration> {
151 let now = Instant::now();
152
153 self.refresh(now);
154
155 if cpu_cost <= self.cpu_tokens.cur_tokens
156 && message_size_cost <= self.message_size_tokens.cur_tokens
157 {
158 self.cpu_tokens.cur_tokens -= cpu_cost;
160 self.message_size_tokens.cur_tokens -= message_size_cost;
161 return Ok(());
162 }
163
164 let cpu_recharge_secs = if cpu_cost > self.cpu_tokens.cur_tokens {
166 ((cpu_cost - self.cpu_tokens.cur_tokens) as f64
167 / self.cpu_tokens.recharge_rate as f64)
168 .ceil() as u64
169 } else {
170 0
171 };
172 let message_recharge_secs = if message_size_cost
173 > self.message_size_tokens.cur_tokens
174 {
175 ((message_size_cost - self.message_size_tokens.cur_tokens) as f64
176 / self.message_size_tokens.recharge_rate as f64)
177 .ceil() as u64
178 } else {
179 0
180 };
181 let recharge_secs = max(cpu_recharge_secs, message_recharge_secs);
182 Err(self.last_update + Duration::from_secs(recharge_secs) - now)
186 }
187
188 pub fn throttle_default(&mut self) -> ThrottleResult {
189 self.throttle(
190 self.cpu_tokens.default_cost,
191 self.message_size_tokens.default_cost,
192 )
193 }
194
195 pub fn throttle(
196 &mut self, cpu_cost: u64, message_size_cost: u64,
197 ) -> ThrottleResult {
198 let now = Instant::now();
199
200 if let Some(until) = self.throttled_until {
202 if now < until {
203 if self.throttled_counter < self.max_throttled_counter {
204 self.throttled_counter += 1;
205 return ThrottleResult::Throttled(until - now);
206 } else {
207 return ThrottleResult::AlreadyThrottled;
208 }
209 } else {
210 self.throttled_until = None;
211 self.throttled_counter = 0;
212 }
213 }
214
215 match self.try_acquire_cost(cpu_cost, message_size_cost) {
216 Ok(_) => ThrottleResult::Success,
217 Err(wait_time) => {
218 self.throttled_until = Some(now + wait_time);
219 ThrottleResult::Throttled(wait_time)
220 }
221 }
222 }
223}
224
225impl FromStr for TokenBucket {
226 type Err = String;
227
228 fn from_str(s: &str) -> Result<Self, String> {
229 let fields: Vec<&str> = s.split(',').collect();
230
231 if fields.len() != 5 {
232 return Err(format!(
233 "invalid number of fields, expected = 9, actual = {}",
234 fields.len()
235 ));
236 }
237
238 let mut nums = Vec::new();
239
240 for f in fields {
241 let num = u64::from_str(f)
242 .map_err(|e| format!("failed to parse number: {:?}", e))?;
243 nums.push(num);
244 }
245
246 let mut bucket =
248 TokenBucket::new(nums[0], nums[1], nums[2], nums[3], 1, 1, 1, 0);
249 bucket.set_max_throttled_counter(nums[4]);
250
251 Ok(bucket)
252 }
253}
254
255#[derive(Default, DeriveMallocSizeOf, Clone)]
256pub struct TokenBucketManager {
257 buckets: HashMap<String, Arc<Mutex<TokenBucket>>>,
259}
260
261impl TokenBucketManager {
262 pub fn register(&mut self, name: String, bucket: TokenBucket) {
263 if self.buckets.contains_key(&name) {
264 panic!("token bucket {:?} already registered", name);
265 }
266
267 self.buckets.insert(name, Arc::new(Mutex::new(bucket)));
268 }
269
270 pub fn get(&self, name: &str) -> Option<Arc<Mutex<TokenBucket>>> {
271 self.buckets.get(name).cloned()
272 }
273
274 pub fn load(
275 toml_file: &str, section: Option<&str>,
276 ) -> Result<Self, String> {
277 let content = read_to_string(toml_file)
278 .map_err(|e| format!("failed to read toml file: {:?}", e))?;
279 let toml_val = content
280 .parse::<toml::Value>()
281 .map_err(|e| format!("failed to parse toml file: {:?}", e))?;
282
283 let val = match section {
284 Some(section) => match toml_val.get(section) {
285 Some(val) => val,
286 None => return Err(format!("section [{}] not found", section)),
287 },
288 None => &toml_val,
289 };
290 let table = val.as_table().expect("not table value");
291
292 let mut manager = TokenBucketManager::default();
293
294 for (k, v) in table.iter() {
295 let v = match v.as_str() {
296 Some(v) => v,
297 None => {
298 return Err(format!(
299 "invalid value type {:?}, string type required",
300 v.type_str()
301 ))
302 }
303 };
304
305 manager.register(k.into(), TokenBucket::from_str(v)?);
306 }
307
308 Ok(manager)
309 }
310}
311
312#[derive(Default, DeriveMallocSizeOf)]
313pub struct ThrottledManager<K: Eq + Hash + MallocSizeOf> {
314 items: HashMap<K, Instant>,
315}
316
317impl<K: Eq + Hash + MallocSizeOf> ThrottledManager<K> {
318 pub fn set_throttled(&mut self, k: K, until: Instant) {
319 let current = self.items.entry(k).or_insert(until);
320 if *current < until {
321 *current = until;
322 }
323 }
324
325 pub fn check_throttled(&mut self, k: &K) -> bool {
326 let until = match self.items.get(k) {
327 Some(until) => until,
328 None => return false,
329 };
330
331 if Instant::now() < *until {
332 return true;
333 }
334
335 self.items.remove(k);
336
337 false
338 }
339}
340
#[cfg(test)]
mod tests {
    use crate::token_bucket::{ThrottleResult, TokenBucket};
    use std::{thread::sleep, time::Duration};

    /// Asserts that an acquisition attempt was rejected with a wait of at
    /// most `limit` (panics if the attempt succeeded).
    fn assert_wait_at_most(attempt: Result<(), Duration>, limit: Duration) {
        assert!(attempt.unwrap_err() <= limit);
    }

    /// Asserts that `result` is `Throttled` with a wait of at most
    /// `wait_time`.
    fn assert_throttled(result: ThrottleResult, wait_time: Duration) {
        if let ThrottleResult::Throttled(waited) = result {
            assert!(waited <= wait_time);
        } else {
            panic!("invalid throttle result");
        }
    }

    #[test]
    fn test_init_tokens() {
        // An empty bucket cannot pay anything until the first recharge tick.
        let mut drained = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        assert_wait_at_most(
            drained.try_acquire_cost(1, 1),
            Duration::from_secs(1),
        );

        // One of three tokens available in each pool.
        let mut partial = TokenBucket::new(3, 1, 1, 1, 3, 1, 1, 1);
        assert_wait_at_most(
            partial.try_acquire_cost(2, 2),
            Duration::from_secs(1),
        );
        assert_eq!(partial.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_acquire() {
        let mut bucket = TokenBucket::full(3, 1, 1, 3, 1, 1);
        let half_sec = Duration::from_millis(500);
        let one_sec = Duration::from_secs(1);

        // Spend all three tokens of each pool.
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
        assert_eq!(bucket.try_acquire_cost(2, 2), Ok(()));

        assert_wait_at_most(bucket.try_acquire_cost(1, 1), one_sec);
        assert_wait_at_most(
            bucket.try_acquire_cost(2, 2),
            Duration::from_secs(2),
        );

        sleep(half_sec);
        assert_wait_at_most(bucket.try_acquire_cost(1, 1), half_sec);

        // After a full second one token per pool has been recharged.
        sleep(half_sec);
        assert_wait_at_most(bucket.try_acquire_cost(2, 2), one_sec);
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_throttled() {
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);

        // The first failure opens a throttle window ...
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        // ... and with no tolerance configured, the next call is rejected
        // outright.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);

        sleep(Duration::from_secs(1));

        // Once the window passes, the recharged bucket serves the request
        // and the throttle state is reset.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::Success);
        assert_eq!(bucket.throttled_until, None);
        assert_eq!(bucket.throttled_counter, 0);
    }

    #[test]
    fn test_tolerate_throttling() {
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        bucket.set_max_throttled_counter(2);

        // The opening failure plus two tolerated repeats all report a wait.
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        for _ in 0..2 {
            assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        }

        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
    }
}