1use malloc_size_of::MallocSizeOf;
6use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
7use parking_lot::Mutex;
8use std::{
9 cmp::{max, min},
10 collections::HashMap,
11 fs::read_to_string,
12 hash::Hash,
13 str::FromStr,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17
/// Outcome of asking a [`TokenBucket`] to pay for a request.
#[derive(PartialEq, Eq, Debug)]
pub enum ThrottleResult {
    /// The requested costs were deducted from the bucket.
    Success,
    /// The request was rejected; the payload is the time remaining until the
    /// bucket's throttle deadline.
    Throttled(Duration),
    /// The request was rejected during an active throttle period after the
    /// allowed number of `Throttled` replies for that period was used up.
    AlreadyThrottled,
}
24
25#[derive(DeriveMallocSizeOf)]
26pub struct ThrottleTokens {
27 max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64, }
32
33impl ThrottleTokens {
34 pub fn new(
35 max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64,
36 ) -> Self {
37 ThrottleTokens {
38 max_tokens,
39 cur_tokens,
40 recharge_rate,
41 default_cost,
42 }
43 }
44}
45
46#[derive(DeriveMallocSizeOf)]
47pub struct TokenBucket {
48 cpu_tokens: ThrottleTokens,
49 message_size_tokens: ThrottleTokens,
50 last_update: Instant,
51 throttled_until: Option<Instant>,
53 throttled_counter: u64,
57 max_throttled_counter: u64,
58}
59
60impl TokenBucket {
61 #[allow(clippy::too_many_arguments)]
62 pub fn new(
63 max_cpu_tokens: u64, cur_cpu_tokens: u64, cpu_token_recharge_rate: u64,
64 default_cpu_cost: u64, max_message_tokens: u64,
65 cur_message_tokens: u64, message_token_recharge_rate: u64,
66 default_message_cost: u64,
67 ) -> Self {
68 assert!(cur_cpu_tokens <= max_cpu_tokens);
69 assert!(cur_message_tokens <= max_message_tokens);
70
71 TokenBucket {
72 cpu_tokens: ThrottleTokens::new(
73 max_cpu_tokens,
74 cur_cpu_tokens,
75 cpu_token_recharge_rate,
76 default_cpu_cost,
77 ),
78 message_size_tokens: ThrottleTokens::new(
79 max_message_tokens,
80 cur_message_tokens,
81 message_token_recharge_rate,
82 default_message_cost,
83 ),
84 last_update: Instant::now(),
85 throttled_until: None,
86 throttled_counter: 0,
87 max_throttled_counter: 0,
88 }
89 }
90
91 pub fn full(
92 max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
93 default_cpu_cost: u64, max_message_tokens: u64,
94 message_token_recharge_rate: u64, default_message_cost: u64,
95 ) -> Self {
96 Self::new(
97 max_cpu_tokens,
98 max_cpu_tokens,
99 cpu_token_recharge_rate,
100 default_cpu_cost,
101 max_message_tokens,
102 max_message_tokens,
103 message_token_recharge_rate,
104 default_message_cost,
105 )
106 }
107
108 pub fn empty(
109 max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
110 default_cpu_cost: u64, max_message_tokens: u64,
111 message_token_recharge_rate: u64, default_message_cost: u64,
112 ) -> Self {
113 Self::new(
114 max_cpu_tokens,
115 0, cpu_token_recharge_rate,
117 default_cpu_cost,
118 max_message_tokens,
119 0, message_token_recharge_rate,
121 default_message_cost,
122 )
123 }
124
125 pub fn set_max_throttled_counter(&mut self, max_throttled_counter: u64) {
126 self.max_throttled_counter = max_throttled_counter;
127 }
128
129 fn refresh(&mut self, now: Instant) {
130 let elapsed_secs = (now - self.last_update).as_secs();
131 if elapsed_secs == 0 {
132 return;
133 }
134
135 let cpu_recharged = self.cpu_tokens.recharge_rate * elapsed_secs;
136 self.cpu_tokens.cur_tokens = min(
137 self.cpu_tokens.max_tokens,
138 self.cpu_tokens.cur_tokens + cpu_recharged,
139 );
140 let message_recharged =
141 self.message_size_tokens.recharge_rate * elapsed_secs;
142 self.message_size_tokens.cur_tokens = min(
143 self.message_size_tokens.max_tokens,
144 self.message_size_tokens.cur_tokens + message_recharged,
145 );
146 self.last_update += Duration::from_secs(elapsed_secs);
147 }
148
149 fn try_acquire_cost(
150 &mut self, cpu_cost: u64, message_size_cost: u64,
151 ) -> Result<(), Duration> {
152 let now = Instant::now();
153
154 self.refresh(now);
155
156 if cpu_cost <= self.cpu_tokens.cur_tokens
157 && message_size_cost <= self.message_size_tokens.cur_tokens
158 {
159 self.cpu_tokens.cur_tokens -= cpu_cost;
161 self.message_size_tokens.cur_tokens -= message_size_cost;
162 return Ok(());
163 }
164
165 let cpu_recharge_secs = if cpu_cost > self.cpu_tokens.cur_tokens {
167 ((cpu_cost - self.cpu_tokens.cur_tokens) as f64
168 / self.cpu_tokens.recharge_rate as f64)
169 .ceil() as u64
170 } else {
171 0
172 };
173 let message_recharge_secs = if message_size_cost
174 > self.message_size_tokens.cur_tokens
175 {
176 ((message_size_cost - self.message_size_tokens.cur_tokens) as f64
177 / self.message_size_tokens.recharge_rate as f64)
178 .ceil() as u64
179 } else {
180 0
181 };
182 let recharge_secs = max(cpu_recharge_secs, message_recharge_secs);
183 Err(self.last_update + Duration::from_secs(recharge_secs) - now)
187 }
188
189 pub fn throttle_default(&mut self) -> ThrottleResult {
190 self.throttle(
191 self.cpu_tokens.default_cost,
192 self.message_size_tokens.default_cost,
193 )
194 }
195
196 pub fn throttle(
197 &mut self, cpu_cost: u64, message_size_cost: u64,
198 ) -> ThrottleResult {
199 let now = Instant::now();
200
201 if let Some(until) = self.throttled_until {
203 if now < until {
204 if self.throttled_counter < self.max_throttled_counter {
205 self.throttled_counter += 1;
206 return ThrottleResult::Throttled(until - now);
207 } else {
208 return ThrottleResult::AlreadyThrottled;
209 }
210 } else {
211 self.throttled_until = None;
212 self.throttled_counter = 0;
213 }
214 }
215
216 match self.try_acquire_cost(cpu_cost, message_size_cost) {
217 Ok(_) => ThrottleResult::Success,
218 Err(wait_time) => {
219 self.throttled_until = Some(now + wait_time);
220 ThrottleResult::Throttled(wait_time)
221 }
222 }
223 }
224}
225
226impl FromStr for TokenBucket {
227 type Err = String;
228
229 fn from_str(s: &str) -> Result<Self, String> {
230 let fields: Vec<&str> = s.split(',').collect();
231
232 if fields.len() != 5 {
233 return Err(format!(
234 "invalid number of fields, expected = 9, actual = {}",
235 fields.len()
236 ));
237 }
238
239 let mut nums = Vec::new();
240
241 for f in fields {
242 let num = u64::from_str(f)
243 .map_err(|e| format!("failed to parse number: {:?}", e))?;
244 nums.push(num);
245 }
246
247 let mut bucket =
249 TokenBucket::new(nums[0], nums[1], nums[2], nums[3], 1, 1, 1, 0);
250 bucket.set_max_throttled_counter(nums[4]);
251
252 Ok(bucket)
253 }
254}
255
256#[derive(Default, DeriveMallocSizeOf, Clone)]
257pub struct TokenBucketManager {
258 buckets: HashMap<String, Arc<Mutex<TokenBucket>>>,
260}
261
262impl TokenBucketManager {
263 pub fn register(&mut self, name: String, bucket: TokenBucket) {
264 if self.buckets.contains_key(&name) {
265 panic!("token bucket {:?} already registered", name);
266 }
267
268 self.buckets.insert(name, Arc::new(Mutex::new(bucket)));
269 }
270
271 pub fn get(&self, name: &str) -> Option<Arc<Mutex<TokenBucket>>> {
272 self.buckets.get(name).cloned()
273 }
274
275 pub fn load(
276 toml_file: &str, section: Option<&str>,
277 ) -> Result<Self, String> {
278 let content = read_to_string(toml_file)
279 .map_err(|e| format!("failed to read toml file: {:?}", e))?;
280 let toml_val = content
281 .parse::<toml::Value>()
282 .map_err(|e| format!("failed to parse toml file: {:?}", e))?;
283
284 let val = match section {
285 Some(section) => match toml_val.get(section) {
286 Some(val) => val,
287 None => return Err(format!("section [{}] not found", section)),
288 },
289 None => &toml_val,
290 };
291 let table = val.as_table().expect("not table value");
292
293 let mut manager = TokenBucketManager::default();
294
295 for (k, v) in table.iter() {
296 let v = match v.as_str() {
297 Some(v) => v,
298 None => {
299 return Err(format!(
300 "invalid value type {:?}, string type required",
301 v.type_str()
302 ))
303 }
304 };
305
306 manager.register(k.into(), TokenBucket::from_str(v)?);
307 }
308
309 Ok(manager)
310 }
311}
312
313#[derive(Default, DeriveMallocSizeOf)]
314pub struct ThrottledManager<K: Eq + Hash + MallocSizeOf> {
315 items: HashMap<K, Instant>,
316}
317
318impl<K: Eq + Hash + MallocSizeOf> ThrottledManager<K> {
319 pub fn set_throttled(&mut self, k: K, until: Instant) {
320 let current = self.items.entry(k).or_insert(until);
321 if *current < until {
322 *current = until;
323 }
324 }
325
326 pub fn check_throttled(&mut self, k: &K) -> bool {
327 let until = match self.items.get(k) {
328 Some(until) => until,
329 None => return false,
330 };
331
332 if Instant::now() < *until {
333 return true;
334 }
335
336 self.items.remove(k);
337
338 false
339 }
340}
341
#[cfg(test)]
mod tests {
    use crate::token_bucket::{ThrottleResult, TokenBucket};
    use std::{thread::sleep, time::Duration};

    /// Fails unless `result` is `Throttled` with a wait of at most `max_wait`.
    fn assert_throttled(result: ThrottleResult, max_wait: Duration) {
        let wait = match result {
            ThrottleResult::Throttled(wait) => wait,
            _ => panic!("invalid throttle result"),
        };
        assert!(wait <= max_wait);
    }

    #[test]
    fn test_init_tokens() {
        let one_sec = Duration::from_secs(1);

        // An empty bucket cannot pay a cost of 1, but refills within a second.
        let mut empty = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        let wait = empty.try_acquire_cost(1, 1).unwrap_err();
        assert!(wait <= one_sec);

        // A bucket holding 1 of 3 tokens rejects a cost of 2 but accepts 1.
        let mut partial = TokenBucket::new(3, 1, 1, 1, 3, 1, 1, 1);
        let wait = partial.try_acquire_cost(2, 2).unwrap_err();
        assert!(wait <= one_sec);
        assert_eq!(partial.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_acquire() {
        let half_sec = Duration::from_millis(500);
        let one_sec = Duration::from_secs(1);
        let two_secs = Duration::from_secs(2);
        let mut bucket = TokenBucket::full(3, 1, 1, 3, 1, 1);

        // Drain the full bucket: 1 + 2 = 3 tokens from each pool.
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
        assert_eq!(bucket.try_acquire_cost(2, 2), Ok(()));

        // Empty now; each token takes one second to come back.
        assert!(bucket.try_acquire_cost(1, 1).unwrap_err() <= one_sec);
        assert!(bucket.try_acquire_cost(2, 2).unwrap_err() <= two_secs);

        // Half a second in, the first token is at most half a second away.
        sleep(half_sec);
        assert!(bucket.try_acquire_cost(1, 1).unwrap_err() <= half_sec);

        // A full second in, one token has been credited: enough for 1, not 2.
        sleep(half_sec);
        assert!(bucket.try_acquire_cost(2, 2).unwrap_err() <= one_sec);
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_throttled() {
        let one_sec = Duration::from_secs(1);
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);

        // First failure reports the wait; with no tolerance configured, the
        // next call inside the window is `AlreadyThrottled`.
        assert_throttled(bucket.throttle(1, 1), one_sec);
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);

        // After the window, the call succeeds and throttle state is reset.
        sleep(one_sec);
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::Success);
        assert_eq!(bucket.throttled_until, None);
        assert_eq!(bucket.throttled_counter, 0);
    }

    #[test]
    fn test_tolerate_throttling() {
        let one_sec = Duration::from_secs(1);
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        bucket.set_max_throttled_counter(2);

        // The initial failure plus two tolerated repeats all report a wait.
        for _ in 0..3 {
            assert_throttled(bucket.throttle(1, 1), one_sec);
        }

        // Tolerance exhausted.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
    }
}