throttling/
token_bucket.rs

1// Copyright 2019 Conflux Foundation. All rights reserved.
2// Conflux is free software and distributed under GNU General Public License.
3// See http://www.gnu.org/licenses/
4
5use malloc_size_of::MallocSizeOf;
6use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
7use parking_lot::Mutex;
8use std::{
9    cmp::{max, min},
10    collections::HashMap,
11    fs::read_to_string,
12    hash::Hash,
13    str::FromStr,
14    sync::Arc,
15    time::{Duration, Instant},
16};
17
/// Outcome of a throttling attempt on a token bucket.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ThrottleResult {
    /// Enough tokens were available and have been consumed.
    Success,
    /// Not enough tokens; the payload is how long to wait before retrying.
    Throttled(Duration),
    /// The bucket is still throttled and no further throttled attempts are
    /// tolerated.
    AlreadyThrottled,
}
24
25#[derive(DeriveMallocSizeOf)]
26pub struct ThrottleTokens {
27    max_tokens: u64,    // maximum tokens allowed in bucket
28    cur_tokens: u64,    // current tokens in bucket
29    recharge_rate: u64, // recharge N tokens per second
30    default_cost: u64,  // default tokens to acquire once
31}
32
33impl ThrottleTokens {
34    pub fn new(
35        max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64,
36    ) -> Self {
37        ThrottleTokens {
38            max_tokens,
39            cur_tokens,
40            recharge_rate,
41            default_cost,
42        }
43    }
44}
45
46#[derive(DeriveMallocSizeOf)]
47pub struct TokenBucket {
48    cpu_tokens: ThrottleTokens,
49    message_size_tokens: ThrottleTokens,
50    last_update: Instant,
51    // once acquire failed, record the next time to acquire tokens
52    throttled_until: Option<Instant>,
53    // client may send multiple requests in a short time, and the
54    // `throttled_counter` is used to tolerate throttling instead
55    // of disconnect the client directly.
56    throttled_counter: u64,
57    max_throttled_counter: u64,
58}
59
60impl TokenBucket {
61    pub fn new(
62        max_cpu_tokens: u64, cur_cpu_tokens: u64, cpu_token_recharge_rate: u64,
63        default_cpu_cost: u64, max_message_tokens: u64,
64        cur_message_tokens: u64, message_token_recharge_rate: u64,
65        default_message_cost: u64,
66    ) -> Self {
67        assert!(cur_cpu_tokens <= max_cpu_tokens);
68        assert!(cur_message_tokens <= max_message_tokens);
69
70        TokenBucket {
71            cpu_tokens: ThrottleTokens::new(
72                max_cpu_tokens,
73                cur_cpu_tokens,
74                cpu_token_recharge_rate,
75                default_cpu_cost,
76            ),
77            message_size_tokens: ThrottleTokens::new(
78                max_message_tokens,
79                cur_message_tokens,
80                message_token_recharge_rate,
81                default_message_cost,
82            ),
83            last_update: Instant::now(),
84            throttled_until: None,
85            throttled_counter: 0,
86            max_throttled_counter: 0,
87        }
88    }
89
90    pub fn full(
91        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
92        default_cpu_cost: u64, max_message_tokens: u64,
93        message_token_recharge_rate: u64, default_message_cost: u64,
94    ) -> Self {
95        Self::new(
96            max_cpu_tokens,
97            max_cpu_tokens,
98            cpu_token_recharge_rate,
99            default_cpu_cost,
100            max_message_tokens,
101            max_message_tokens,
102            message_token_recharge_rate,
103            default_message_cost,
104        )
105    }
106
107    pub fn empty(
108        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
109        default_cpu_cost: u64, max_message_tokens: u64,
110        message_token_recharge_rate: u64, default_message_cost: u64,
111    ) -> Self {
112        Self::new(
113            max_cpu_tokens,
114            0, /* cur_cpu_tokens */
115            cpu_token_recharge_rate,
116            default_cpu_cost,
117            max_message_tokens,
118            0, /* cur_message_tokens */
119            message_token_recharge_rate,
120            default_message_cost,
121        )
122    }
123
124    pub fn set_max_throttled_counter(&mut self, max_throttled_counter: u64) {
125        self.max_throttled_counter = max_throttled_counter;
126    }
127
128    fn refresh(&mut self, now: Instant) {
129        let elapsed_secs = (now - self.last_update).as_secs();
130        if elapsed_secs == 0 {
131            return;
132        }
133
134        let cpu_recharged = self.cpu_tokens.recharge_rate * elapsed_secs;
135        self.cpu_tokens.cur_tokens = min(
136            self.cpu_tokens.max_tokens,
137            self.cpu_tokens.cur_tokens + cpu_recharged,
138        );
139        let message_recharged =
140            self.message_size_tokens.recharge_rate * elapsed_secs;
141        self.message_size_tokens.cur_tokens = min(
142            self.message_size_tokens.max_tokens,
143            self.message_size_tokens.cur_tokens + message_recharged,
144        );
145        self.last_update += Duration::from_secs(elapsed_secs);
146    }
147
148    fn try_acquire_cost(
149        &mut self, cpu_cost: u64, message_size_cost: u64,
150    ) -> Result<(), Duration> {
151        let now = Instant::now();
152
153        self.refresh(now);
154
155        if cpu_cost <= self.cpu_tokens.cur_tokens
156            && message_size_cost <= self.message_size_tokens.cur_tokens
157        {
158            // tokens enough
159            self.cpu_tokens.cur_tokens -= cpu_cost;
160            self.message_size_tokens.cur_tokens -= message_size_cost;
161            return Ok(());
162        }
163
164        // tokens not enough and throttled
165        let cpu_recharge_secs = if cpu_cost > self.cpu_tokens.cur_tokens {
166            ((cpu_cost - self.cpu_tokens.cur_tokens) as f64
167                / self.cpu_tokens.recharge_rate as f64)
168                .ceil() as u64
169        } else {
170            0
171        };
172        let message_recharge_secs = if message_size_cost
173            > self.message_size_tokens.cur_tokens
174        {
175            ((message_size_cost - self.message_size_tokens.cur_tokens) as f64
176                / self.message_size_tokens.recharge_rate as f64)
177                .ceil() as u64
178        } else {
179            0
180        };
181        let recharge_secs = max(cpu_recharge_secs, message_recharge_secs);
182        // `refresh` ensures the difference in `self.last_update` and `now` is
183        // less than 1 second, and `recharge_secs` is at least 1 second,
184        // so this will not underflow.
185        Err(self.last_update + Duration::from_secs(recharge_secs) - now)
186    }
187
188    pub fn throttle_default(&mut self) -> ThrottleResult {
189        self.throttle(
190            self.cpu_tokens.default_cost,
191            self.message_size_tokens.default_cost,
192        )
193    }
194
195    pub fn throttle(
196        &mut self, cpu_cost: u64, message_size_cost: u64,
197    ) -> ThrottleResult {
198        let now = Instant::now();
199
200        // already throttled
201        if let Some(until) = self.throttled_until {
202            if now < until {
203                if self.throttled_counter < self.max_throttled_counter {
204                    self.throttled_counter += 1;
205                    return ThrottleResult::Throttled(until - now);
206                } else {
207                    return ThrottleResult::AlreadyThrottled;
208                }
209            } else {
210                self.throttled_until = None;
211                self.throttled_counter = 0;
212            }
213        }
214
215        match self.try_acquire_cost(cpu_cost, message_size_cost) {
216            Ok(_) => ThrottleResult::Success,
217            Err(wait_time) => {
218                self.throttled_until = Some(now + wait_time);
219                ThrottleResult::Throttled(wait_time)
220            }
221        }
222    }
223}
224
225impl FromStr for TokenBucket {
226    type Err = String;
227
228    fn from_str(s: &str) -> Result<Self, String> {
229        let fields: Vec<&str> = s.split(',').collect();
230
231        if fields.len() != 5 {
232            return Err(format!(
233                "invalid number of fields, expected = 9, actual = {}",
234                fields.len()
235            ));
236        }
237
238        let mut nums = Vec::new();
239
240        for f in fields {
241            let num = u64::from_str(f)
242                .map_err(|e| format!("failed to parse number: {:?}", e))?;
243            nums.push(num);
244        }
245
246        // TODO: Correctly set the message token information.
247        let mut bucket =
248            TokenBucket::new(nums[0], nums[1], nums[2], nums[3], 1, 1, 1, 0);
249        bucket.set_max_throttled_counter(nums[4]);
250
251        Ok(bucket)
252    }
253}
254
255#[derive(Default, DeriveMallocSizeOf, Clone)]
256pub struct TokenBucketManager {
257    // manage buckets by name
258    buckets: HashMap<String, Arc<Mutex<TokenBucket>>>,
259}
260
261impl TokenBucketManager {
262    pub fn register(&mut self, name: String, bucket: TokenBucket) {
263        if self.buckets.contains_key(&name) {
264            panic!("token bucket {:?} already registered", name);
265        }
266
267        self.buckets.insert(name, Arc::new(Mutex::new(bucket)));
268    }
269
270    pub fn get(&self, name: &str) -> Option<Arc<Mutex<TokenBucket>>> {
271        self.buckets.get(name).cloned()
272    }
273
274    pub fn load(
275        toml_file: &str, section: Option<&str>,
276    ) -> Result<Self, String> {
277        let content = read_to_string(toml_file)
278            .map_err(|e| format!("failed to read toml file: {:?}", e))?;
279        let toml_val = content
280            .parse::<toml::Value>()
281            .map_err(|e| format!("failed to parse toml file: {:?}", e))?;
282
283        let val = match section {
284            Some(section) => match toml_val.get(section) {
285                Some(val) => val,
286                None => return Err(format!("section [{}] not found", section)),
287            },
288            None => &toml_val,
289        };
290        let table = val.as_table().expect("not table value");
291
292        let mut manager = TokenBucketManager::default();
293
294        for (k, v) in table.iter() {
295            let v = match v.as_str() {
296                Some(v) => v,
297                None => {
298                    return Err(format!(
299                        "invalid value type {:?}, string type required",
300                        v.type_str()
301                    ))
302                }
303            };
304
305            manager.register(k.into(), TokenBucket::from_str(v)?);
306        }
307
308        Ok(manager)
309    }
310}
311
312#[derive(Default, DeriveMallocSizeOf)]
313pub struct ThrottledManager<K: Eq + Hash + MallocSizeOf> {
314    items: HashMap<K, Instant>,
315}
316
317impl<K: Eq + Hash + MallocSizeOf> ThrottledManager<K> {
318    pub fn set_throttled(&mut self, k: K, until: Instant) {
319        let current = self.items.entry(k).or_insert(until);
320        if *current < until {
321            *current = until;
322        }
323    }
324
325    pub fn check_throttled(&mut self, k: &K) -> bool {
326        let until = match self.items.get(k) {
327            Some(until) => until,
328            None => return false,
329        };
330
331        if Instant::now() < *until {
332            return true;
333        }
334
335        self.items.remove(k);
336
337        false
338    }
339}
340
#[cfg(test)]
mod tests {
    use crate::token_bucket::{ThrottleResult, TokenBucket};
    use std::{thread::sleep, time::Duration};

    /// Asserts that acquiring `cost` CPU tokens and `cost` message tokens
    /// fails with a wait time of at most `limit`.
    fn assert_wait_at_most(
        bucket: &mut TokenBucket, cost: u64, limit: Duration,
    ) {
        let wait = bucket.try_acquire_cost(cost, cost).unwrap_err();
        assert!(wait <= limit);
    }

    /// Asserts that `result` is `Throttled` with a wait of at most
    /// `wait_time`.
    fn assert_throttled(result: ThrottleResult, wait_time: Duration) {
        let waited = match result {
            ThrottleResult::Throttled(d) => d,
            _ => panic!("invalid throttle result"),
        };
        assert!(waited <= wait_time);
    }

    #[test]
    fn test_init_tokens() {
        // An empty bucket has nothing to hand out yet.
        let mut empty_bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        assert_wait_at_most(&mut empty_bucket, 1, Duration::from_secs(1));

        // A bucket holding 1 token cannot pay 2, but can pay 1.
        let mut one_token_bucket = TokenBucket::new(3, 1, 1, 1, 3, 1, 1, 1);
        assert_wait_at_most(&mut one_token_bucket, 2, Duration::from_secs(1));
        assert_eq!(one_token_bucket.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_acquire() {
        let mut bucket = TokenBucket::full(3, 1, 1, 3, 1, 1);

        // Enough tokens: 1 + 2 drains the bucket.
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
        assert_eq!(bucket.try_acquire_cost(2, 2), Ok(()));

        // Not enough tokens any more.
        assert_wait_at_most(&mut bucket, 1, Duration::from_secs(1));
        assert_wait_at_most(&mut bucket, 2, Duration::from_secs(2));

        // After 0.5s nothing has been recharged yet.
        sleep(Duration::from_millis(500));
        assert_wait_at_most(&mut bucket, 1, Duration::from_millis(500));

        // After another 0.5s exactly 1 token has been recharged.
        sleep(Duration::from_millis(500));

        // 2 tokens are still too many since only 1 was recharged.
        assert_wait_at_most(&mut bucket, 2, Duration::from_secs(1));

        // The single recharged token can be acquired.
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_throttled() {
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);

        // First attempt on the empty bucket gets throttled.
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));

        // No tolerance configured: the next attempt is rejected outright.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);

        sleep(Duration::from_secs(1));

        // Once the deadline passed the throttle state is cleared.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::Success);
        assert_eq!(bucket.throttled_until, None);
        assert_eq!(bucket.throttled_counter, 0);
    }

    #[test]
    fn test_tolerate_throttling() {
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        bucket.set_max_throttled_counter(2);

        // First attempt on the empty bucket gets throttled.
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));

        // Two more throttled attempts are tolerated.
        for _ in 0..2 {
            assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        }

        // The tolerance is used up.
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
    }
}