throttling/
token_bucket.rs

1// Copyright 2019 Conflux Foundation. All rights reserved.
2// Conflux is free software and distributed under GNU General Public License.
3// See http://www.gnu.org/licenses/
4
5use malloc_size_of::MallocSizeOf;
6use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
7use parking_lot::Mutex;
8use std::{
9    cmp::{max, min},
10    collections::HashMap,
11    fs::read_to_string,
12    hash::Hash,
13    str::FromStr,
14    sync::Arc,
15    time::{Duration, Instant},
16};
17
/// Outcome of a `TokenBucket::throttle` call.
#[derive(Debug, Eq, PartialEq)]
pub enum ThrottleResult {
    /// Enough tokens were available and they have been consumed.
    Success,
    /// Not enough tokens: the caller should wait the given duration before
    /// retrying. Also returned for repeated requests made while still
    /// throttled, as long as the bucket's tolerance counter is not exhausted.
    Throttled(Duration),
    /// Still throttled and the tolerance counter is exhausted.
    AlreadyThrottled,
}
24
/// Token counters for a single throttled resource. `TokenBucket` keeps one
/// for CPU cost and one for message size.
#[derive(DeriveMallocSizeOf)]
pub struct ThrottleTokens {
    max_tokens: u64,    // maximum tokens allowed in bucket
    cur_tokens: u64,    // current tokens in bucket
    recharge_rate: u64, // recharge N tokens per second
    default_cost: u64,  // default tokens to acquire once
}
32
33impl ThrottleTokens {
34    pub fn new(
35        max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64,
36    ) -> Self {
37        ThrottleTokens {
38            max_tokens,
39            cur_tokens,
40            recharge_rate,
41            default_cost,
42        }
43    }
44}
45
/// Rate limiter that charges every request against two token counters:
/// CPU cost and message size. A request succeeds only when both counters can
/// cover it.
#[derive(DeriveMallocSizeOf)]
pub struct TokenBucket {
    cpu_tokens: ThrottleTokens,
    message_size_tokens: ThrottleTokens,
    // point in time up to which recharged tokens have been credited; it is
    // advanced in whole seconds by `refresh`
    last_update: Instant,
    // once acquire failed, record the next time to acquire tokens
    throttled_until: Option<Instant>,
    // client may send multiple requests in a short time, and the
    // `throttled_counter` is used to tolerate throttling instead
    // of disconnect the client directly.
    throttled_counter: u64,
    // number of extra requests answered with `Throttled` (rather than
    // `AlreadyThrottled`) while `throttled_until` is still in the future
    max_throttled_counter: u64,
}
59
60impl TokenBucket {
61    #[allow(clippy::too_many_arguments)]
62    pub fn new(
63        max_cpu_tokens: u64, cur_cpu_tokens: u64, cpu_token_recharge_rate: u64,
64        default_cpu_cost: u64, max_message_tokens: u64,
65        cur_message_tokens: u64, message_token_recharge_rate: u64,
66        default_message_cost: u64,
67    ) -> Self {
68        assert!(cur_cpu_tokens <= max_cpu_tokens);
69        assert!(cur_message_tokens <= max_message_tokens);
70
71        TokenBucket {
72            cpu_tokens: ThrottleTokens::new(
73                max_cpu_tokens,
74                cur_cpu_tokens,
75                cpu_token_recharge_rate,
76                default_cpu_cost,
77            ),
78            message_size_tokens: ThrottleTokens::new(
79                max_message_tokens,
80                cur_message_tokens,
81                message_token_recharge_rate,
82                default_message_cost,
83            ),
84            last_update: Instant::now(),
85            throttled_until: None,
86            throttled_counter: 0,
87            max_throttled_counter: 0,
88        }
89    }
90
91    pub fn full(
92        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
93        default_cpu_cost: u64, max_message_tokens: u64,
94        message_token_recharge_rate: u64, default_message_cost: u64,
95    ) -> Self {
96        Self::new(
97            max_cpu_tokens,
98            max_cpu_tokens,
99            cpu_token_recharge_rate,
100            default_cpu_cost,
101            max_message_tokens,
102            max_message_tokens,
103            message_token_recharge_rate,
104            default_message_cost,
105        )
106    }
107
108    pub fn empty(
109        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
110        default_cpu_cost: u64, max_message_tokens: u64,
111        message_token_recharge_rate: u64, default_message_cost: u64,
112    ) -> Self {
113        Self::new(
114            max_cpu_tokens,
115            0, /* cur_cpu_tokens */
116            cpu_token_recharge_rate,
117            default_cpu_cost,
118            max_message_tokens,
119            0, /* cur_message_tokens */
120            message_token_recharge_rate,
121            default_message_cost,
122        )
123    }
124
125    pub fn set_max_throttled_counter(&mut self, max_throttled_counter: u64) {
126        self.max_throttled_counter = max_throttled_counter;
127    }
128
129    fn refresh(&mut self, now: Instant) {
130        let elapsed_secs = (now - self.last_update).as_secs();
131        if elapsed_secs == 0 {
132            return;
133        }
134
135        let cpu_recharged = self.cpu_tokens.recharge_rate * elapsed_secs;
136        self.cpu_tokens.cur_tokens = min(
137            self.cpu_tokens.max_tokens,
138            self.cpu_tokens.cur_tokens + cpu_recharged,
139        );
140        let message_recharged =
141            self.message_size_tokens.recharge_rate * elapsed_secs;
142        self.message_size_tokens.cur_tokens = min(
143            self.message_size_tokens.max_tokens,
144            self.message_size_tokens.cur_tokens + message_recharged,
145        );
146        self.last_update += Duration::from_secs(elapsed_secs);
147    }
148
149    fn try_acquire_cost(
150        &mut self, cpu_cost: u64, message_size_cost: u64,
151    ) -> Result<(), Duration> {
152        let now = Instant::now();
153
154        self.refresh(now);
155
156        if cpu_cost <= self.cpu_tokens.cur_tokens
157            && message_size_cost <= self.message_size_tokens.cur_tokens
158        {
159            // tokens enough
160            self.cpu_tokens.cur_tokens -= cpu_cost;
161            self.message_size_tokens.cur_tokens -= message_size_cost;
162            return Ok(());
163        }
164
165        // tokens not enough and throttled
166        let cpu_recharge_secs = if cpu_cost > self.cpu_tokens.cur_tokens {
167            ((cpu_cost - self.cpu_tokens.cur_tokens) as f64
168                / self.cpu_tokens.recharge_rate as f64)
169                .ceil() as u64
170        } else {
171            0
172        };
173        let message_recharge_secs = if message_size_cost
174            > self.message_size_tokens.cur_tokens
175        {
176            ((message_size_cost - self.message_size_tokens.cur_tokens) as f64
177                / self.message_size_tokens.recharge_rate as f64)
178                .ceil() as u64
179        } else {
180            0
181        };
182        let recharge_secs = max(cpu_recharge_secs, message_recharge_secs);
183        // `refresh` ensures the difference in `self.last_update` and `now` is
184        // less than 1 second, and `recharge_secs` is at least 1 second,
185        // so this will not underflow.
186        Err(self.last_update + Duration::from_secs(recharge_secs) - now)
187    }
188
189    pub fn throttle_default(&mut self) -> ThrottleResult {
190        self.throttle(
191            self.cpu_tokens.default_cost,
192            self.message_size_tokens.default_cost,
193        )
194    }
195
196    pub fn throttle(
197        &mut self, cpu_cost: u64, message_size_cost: u64,
198    ) -> ThrottleResult {
199        let now = Instant::now();
200
201        // already throttled
202        if let Some(until) = self.throttled_until {
203            if now < until {
204                if self.throttled_counter < self.max_throttled_counter {
205                    self.throttled_counter += 1;
206                    return ThrottleResult::Throttled(until - now);
207                } else {
208                    return ThrottleResult::AlreadyThrottled;
209                }
210            } else {
211                self.throttled_until = None;
212                self.throttled_counter = 0;
213            }
214        }
215
216        match self.try_acquire_cost(cpu_cost, message_size_cost) {
217            Ok(_) => ThrottleResult::Success,
218            Err(wait_time) => {
219                self.throttled_until = Some(now + wait_time);
220                ThrottleResult::Throttled(wait_time)
221            }
222        }
223    }
224}
225
226impl FromStr for TokenBucket {
227    type Err = String;
228
229    fn from_str(s: &str) -> Result<Self, String> {
230        let fields: Vec<&str> = s.split(',').collect();
231
232        if fields.len() != 5 {
233            return Err(format!(
234                "invalid number of fields, expected = 9, actual = {}",
235                fields.len()
236            ));
237        }
238
239        let mut nums = Vec::new();
240
241        for f in fields {
242            let num = u64::from_str(f)
243                .map_err(|e| format!("failed to parse number: {:?}", e))?;
244            nums.push(num);
245        }
246
247        // TODO: Correctly set the message token information.
248        let mut bucket =
249            TokenBucket::new(nums[0], nums[1], nums[2], nums[3], 1, 1, 1, 0);
250        bucket.set_max_throttled_counter(nums[4]);
251
252        Ok(bucket)
253    }
254}
255
/// Registry of named token buckets.
///
/// Buckets are stored as `Arc<Mutex<_>>`, so cloning the manager shares the
/// same underlying buckets rather than copying them.
#[derive(Default, DeriveMallocSizeOf, Clone)]
pub struct TokenBucketManager {
    // manage buckets by name
    buckets: HashMap<String, Arc<Mutex<TokenBucket>>>,
}
261
262impl TokenBucketManager {
263    pub fn register(&mut self, name: String, bucket: TokenBucket) {
264        if self.buckets.contains_key(&name) {
265            panic!("token bucket {:?} already registered", name);
266        }
267
268        self.buckets.insert(name, Arc::new(Mutex::new(bucket)));
269    }
270
271    pub fn get(&self, name: &str) -> Option<Arc<Mutex<TokenBucket>>> {
272        self.buckets.get(name).cloned()
273    }
274
275    pub fn load(
276        toml_file: &str, section: Option<&str>,
277    ) -> Result<Self, String> {
278        let content = read_to_string(toml_file)
279            .map_err(|e| format!("failed to read toml file: {:?}", e))?;
280        let toml_val = content
281            .parse::<toml::Value>()
282            .map_err(|e| format!("failed to parse toml file: {:?}", e))?;
283
284        let val = match section {
285            Some(section) => match toml_val.get(section) {
286                Some(val) => val,
287                None => return Err(format!("section [{}] not found", section)),
288            },
289            None => &toml_val,
290        };
291        let table = val.as_table().expect("not table value");
292
293        let mut manager = TokenBucketManager::default();
294
295        for (k, v) in table.iter() {
296            let v = match v.as_str() {
297                Some(v) => v,
298                None => {
299                    return Err(format!(
300                        "invalid value type {:?}, string type required",
301                        v.type_str()
302                    ))
303                }
304            };
305
306            manager.register(k.into(), TokenBucket::from_str(v)?);
307        }
308
309        Ok(manager)
310    }
311}
312
/// Tracks keys that are throttled until a deadline.
#[derive(Default, DeriveMallocSizeOf)]
pub struct ThrottledManager<K: Eq + Hash + MallocSizeOf> {
    // key -> instant until which the key is considered throttled
    items: HashMap<K, Instant>,
}
317
318impl<K: Eq + Hash + MallocSizeOf> ThrottledManager<K> {
319    pub fn set_throttled(&mut self, k: K, until: Instant) {
320        let current = self.items.entry(k).or_insert(until);
321        if *current < until {
322            *current = until;
323        }
324    }
325
326    pub fn check_throttled(&mut self, k: &K) -> bool {
327        let until = match self.items.get(k) {
328            Some(until) => until,
329            None => return false,
330        };
331
332        if Instant::now() < *until {
333            return true;
334        }
335
336        self.items.remove(k);
337
338        false
339    }
340}
341
// Unit tests for `TokenBucket`. They use real `sleep` calls and rely on
// tokens being recharged in whole-second steps, so assertions compare wait
// times with upper bounds rather than exact values.
#[cfg(test)]
mod tests {
    use crate::token_bucket::{ThrottleResult, TokenBucket};
    use std::{thread::sleep, time::Duration};

    #[test]
    fn test_init_tokens() {
        // empty bucket
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_secs(1)
        );

        // 1 token
        // (1 CPU token and 1 message token; acquiring 2 needs one more second)
        let mut bucket = TokenBucket::new(3, 1, 1, 1, 3, 1, 1, 1);
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(1)
        );
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }

    #[test]
    fn test_acquire() {
        let mut bucket = TokenBucket::full(3, 1, 1, 3, 1, 1);

        // Token enough
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
        assert_eq!(bucket.try_acquire_cost(2, 2), Ok(()));

        // Token not enough
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_secs(1)
        );
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(2)
        );

        // Sleep 0.5s, but not recharged
        // (recharge is credited only per whole elapsed second)
        sleep(Duration::from_millis(500));
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_millis(500)
        );

        // Sleep 0.5s, and recharged 1 token
        sleep(Duration::from_millis(500));

        // cannot acquire 2 tokens since only 1 recharged
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(1)
        );

        // acquire the recharged 1 token
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }

    // Asserts that `result` is `Throttled` with a wait of at most `wait_time`.
    fn assert_throttled(result: ThrottleResult, wait_time: Duration) {
        match result {
            ThrottleResult::Throttled(d) => assert!(d <= wait_time),
            _ => panic!("invalid throttle result"),
        }
    }

    #[test]
    fn test_throttled() {
        // empty bucket
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);

        // throttled
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));

        // already throttled
        // (max_throttled_counter defaults to 0, so no extra tolerance)
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);

        sleep(Duration::from_secs(1));

        // the deadline has passed and one token was recharged, so the
        // throttling state is cleared
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::Success);
        assert_eq!(bucket.throttled_until, None);
        assert_eq!(bucket.throttled_counter, 0);
    }

    #[test]
    fn test_tolerate_throttling() {
        // empty bucket
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        bucket.set_max_throttled_counter(2);

        // throttled
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));

        // tolerate another 2 times
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));

        // already throttled
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
    }
}