use malloc_size_of::MallocSizeOf;
use malloc_size_of_derive::MallocSizeOf as DeriveMallocSizeOf;
use parking_lot::Mutex;
use std::{
    cmp::{max, min},
    collections::HashMap,
    fs::read_to_string,
    hash::Hash,
    str::FromStr,
    sync::Arc,
    time::{Duration, Instant},
};
/// Outcome of a [`TokenBucket::throttle`] call.
#[derive(PartialEq, Eq, Debug)]
pub enum ThrottleResult {
    /// The requested tokens were available and have been consumed.
    Success,
    /// The request was rejected; the payload is how long the caller should
    /// wait before the bucket expects to be able to serve it.
    Throttled(Duration),
    /// The bucket is still inside an earlier throttle window and has already
    /// reported `Throttled` as many extra times as it is allowed to.
    AlreadyThrottled,
}
/// Token accounting for a single throttled resource (CPU cost or message
/// size, as used by `TokenBucket`).
#[derive(DeriveMallocSizeOf)]
pub struct ThrottleTokens {
    /// Upper bound for `cur_tokens`; recharging never exceeds it.
    max_tokens: u64,
    /// Tokens currently available to spend.
    cur_tokens: u64,
    /// Tokens credited per whole elapsed second.
    recharge_rate: u64,
    /// Cost charged by `TokenBucket::throttle_default`.
    default_cost: u64,
}
impl ThrottleTokens {
    pub fn new(
        max_tokens: u64, cur_tokens: u64, recharge_rate: u64, default_cost: u64,
    ) -> Self {
        ThrottleTokens {
            max_tokens,
            cur_tokens,
            recharge_rate,
            default_cost,
        }
    }
}
/// Rate limiter with two token pools, CPU cost and message size. A request
/// succeeds only if both pools can pay for it.
#[derive(DeriveMallocSizeOf)]
pub struct TokenBucket {
    /// Pool charged with the CPU cost of a request.
    cpu_tokens: ThrottleTokens,
    /// Pool charged with the message-size cost of a request.
    message_size_tokens: ThrottleTokens,
    /// Instant up to which recharge has been credited. It is advanced in
    /// whole seconds so that fractional time carries over.
    last_update: Instant,
    /// End of the current throttle window, if one is open.
    throttled_until: Option<Instant>,
    /// Extra `Throttled` results already returned inside the current window.
    throttled_counter: u64,
    /// Maximum extra `Throttled` results per window before
    /// `AlreadyThrottled` is returned instead.
    max_throttled_counter: u64,
}
impl TokenBucket {
    pub fn new(
        max_cpu_tokens: u64, cur_cpu_tokens: u64, cpu_token_recharge_rate: u64,
        default_cpu_cost: u64, max_message_tokens: u64,
        cur_message_tokens: u64, message_token_recharge_rate: u64,
        default_message_cost: u64,
    ) -> Self {
        assert!(cur_cpu_tokens <= max_cpu_tokens);
        assert!(cur_message_tokens <= max_message_tokens);
        TokenBucket {
            cpu_tokens: ThrottleTokens::new(
                max_cpu_tokens,
                cur_cpu_tokens,
                cpu_token_recharge_rate,
                default_cpu_cost,
            ),
            message_size_tokens: ThrottleTokens::new(
                max_message_tokens,
                cur_message_tokens,
                message_token_recharge_rate,
                default_message_cost,
            ),
            last_update: Instant::now(),
            throttled_until: None,
            throttled_counter: 0,
            max_throttled_counter: 0,
        }
    }
    pub fn full(
        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
        default_cpu_cost: u64, max_message_tokens: u64,
        message_token_recharge_rate: u64, default_message_cost: u64,
    ) -> Self {
        Self::new(
            max_cpu_tokens,
            max_cpu_tokens,
            cpu_token_recharge_rate,
            default_cpu_cost,
            max_message_tokens,
            max_message_tokens,
            message_token_recharge_rate,
            default_message_cost,
        )
    }
    pub fn empty(
        max_cpu_tokens: u64, cpu_token_recharge_rate: u64,
        default_cpu_cost: u64, max_message_tokens: u64,
        message_token_recharge_rate: u64, default_message_cost: u64,
    ) -> Self {
        Self::new(
            max_cpu_tokens,
            0, cpu_token_recharge_rate,
            default_cpu_cost,
            max_message_tokens,
            0, message_token_recharge_rate,
            default_message_cost,
        )
    }
    pub fn set_max_throttled_counter(&mut self, max_throttled_counter: u64) {
        self.max_throttled_counter = max_throttled_counter;
    }
    fn refresh(&mut self, now: Instant) {
        let elapsed_secs = (now - self.last_update).as_secs();
        if elapsed_secs == 0 {
            return;
        }
        let cpu_recharged = self.cpu_tokens.recharge_rate * elapsed_secs;
        self.cpu_tokens.cur_tokens = min(
            self.cpu_tokens.max_tokens,
            self.cpu_tokens.cur_tokens + cpu_recharged,
        );
        let message_recharged =
            self.message_size_tokens.recharge_rate * elapsed_secs;
        self.message_size_tokens.cur_tokens = min(
            self.message_size_tokens.max_tokens,
            self.message_size_tokens.cur_tokens + message_recharged,
        );
        self.last_update += Duration::from_secs(elapsed_secs);
    }
    fn try_acquire_cost(
        &mut self, cpu_cost: u64, message_size_cost: u64,
    ) -> Result<(), Duration> {
        let now = Instant::now();
        self.refresh(now);
        if cpu_cost <= self.cpu_tokens.cur_tokens
            && message_size_cost <= self.message_size_tokens.cur_tokens
        {
            self.cpu_tokens.cur_tokens -= cpu_cost;
            self.message_size_tokens.cur_tokens -= message_size_cost;
            return Ok(());
        }
        let cpu_recharge_secs = if cpu_cost > self.cpu_tokens.cur_tokens {
            ((cpu_cost - self.cpu_tokens.cur_tokens) as f64
                / self.cpu_tokens.recharge_rate as f64)
                .ceil() as u64
        } else {
            0
        };
        let message_recharge_secs = if message_size_cost
            > self.message_size_tokens.cur_tokens
        {
            ((message_size_cost - self.message_size_tokens.cur_tokens) as f64
                / self.message_size_tokens.recharge_rate as f64)
                .ceil() as u64
        } else {
            0
        };
        let recharge_secs = max(cpu_recharge_secs, message_recharge_secs);
        Err(self.last_update + Duration::from_secs(recharge_secs) - now)
    }
    pub fn throttle_default(&mut self) -> ThrottleResult {
        self.throttle(
            self.cpu_tokens.default_cost,
            self.message_size_tokens.default_cost,
        )
    }
    pub fn throttle(
        &mut self, cpu_cost: u64, message_size_cost: u64,
    ) -> ThrottleResult {
        let now = Instant::now();
        if let Some(until) = self.throttled_until {
            if now < until {
                if self.throttled_counter < self.max_throttled_counter {
                    self.throttled_counter += 1;
                    return ThrottleResult::Throttled(until - now);
                } else {
                    return ThrottleResult::AlreadyThrottled;
                }
            } else {
                self.throttled_until = None;
                self.throttled_counter = 0;
            }
        }
        match self.try_acquire_cost(cpu_cost, message_size_cost) {
            Ok(_) => ThrottleResult::Success,
            Err(wait_time) => {
                self.throttled_until = Some(now + wait_time);
                ThrottleResult::Throttled(wait_time)
            }
        }
    }
}
impl FromStr for TokenBucket {
    type Err = String;

    /// Parses a bucket from five comma-separated unsigned integers:
    /// `max_cpu_tokens,cur_cpu_tokens,cpu_token_recharge_rate,default_cpu_cost,max_throttled_counter`.
    ///
    /// Whitespace around each field is ignored. The format has no
    /// message-size parameters, so that pool is fixed at max 1, current 1,
    /// rate 1 and default cost 0. As a result, `throttle_default` charges no
    /// message-size tokens.
    ///
    /// # Errors
    ///
    /// Returns an error if there are not exactly 5 fields, if a field is not
    /// a valid `u64`, or if the current CPU tokens exceed the maximum.
    fn from_str(s: &str) -> Result<Self, String> {
        let fields: Vec<&str> = s.split(',').collect();
        if fields.len() != 5 {
            return Err(format!(
                "invalid number of fields, expected = 5, actual = {}",
                fields.len()
            ));
        }
        let nums = fields
            .iter()
            .map(|field| {
                u64::from_str(field.trim())
                    .map_err(|e| format!("failed to parse number: {:?}", e))
            })
            .collect::<Result<Vec<u64>, String>>()?;
        let (max_cpu, cur_cpu) = (nums[0], nums[1]);
        // `TokenBucket::new` asserts this invariant. Check it here so that
        // bad configuration returns an `Err` instead of panicking.
        if cur_cpu > max_cpu {
            return Err(format!(
                "current cpu tokens ({}) exceed max cpu tokens ({})",
                cur_cpu, max_cpu
            ));
        }
        let mut bucket =
            TokenBucket::new(max_cpu, cur_cpu, nums[2], nums[3], 1, 1, 1, 0);
        bucket.set_max_throttled_counter(nums[4]);
        Ok(bucket)
    }
}
/// Registry of named token buckets.
///
/// Each bucket is held as `Arc<Mutex<TokenBucket>>`, so cloning the manager
/// shares the same buckets rather than copying them.
#[derive(Clone, Default, DeriveMallocSizeOf)]
pub struct TokenBucketManager {
    /// Buckets keyed by the name they were registered under.
    buckets: HashMap<String, Arc<Mutex<TokenBucket>>>,
}
impl TokenBucketManager {
    /// Registers `bucket` under `name`.
    ///
    /// # Panics
    ///
    /// Panics if a bucket is already registered under `name`.
    pub fn register(&mut self, name: String, bucket: TokenBucket) {
        use std::collections::hash_map::Entry;
        // A single entry lookup replaces the contains_key + insert pair.
        match self.buckets.entry(name) {
            Entry::Occupied(slot) => {
                panic!("token bucket {:?} already registered", slot.key())
            }
            Entry::Vacant(slot) => {
                slot.insert(Arc::new(Mutex::new(bucket)));
            }
        }
    }

    /// Returns a shared handle to the bucket registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<Mutex<TokenBucket>>> {
        self.buckets.get(name).cloned()
    }

    /// Builds a manager from a TOML file.
    ///
    /// When `section` is given, buckets are read from that table; otherwise
    /// they are read from the document root. Each entry must be a string in
    /// the format accepted by `TokenBucket::from_str` and is registered
    /// under its key.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or parsed, if the section
    /// is missing or is not a table, if an entry is not a string, or if an
    /// entry fails to parse as a bucket.
    pub fn load(
        toml_file: &str, section: Option<&str>,
    ) -> Result<Self, String> {
        let content = read_to_string(toml_file)
            .map_err(|e| format!("failed to read toml file: {:?}", e))?;
        let toml_val = content
            .parse::<toml::Value>()
            .map_err(|e| format!("failed to parse toml file: {:?}", e))?;
        let val = match section {
            Some(section) => toml_val
                .get(section)
                .ok_or_else(|| format!("section [{}] not found", section))?,
            None => &toml_val,
        };
        // The selected section comes from user input, so a non-table value
        // must be reported as an error rather than a panic.
        let table = val.as_table().ok_or_else(|| {
            format!(
                "invalid value type {:?}, table type required",
                val.type_str()
            )
        })?;
        let mut manager = TokenBucketManager::default();
        for (name, value) in table.iter() {
            let spec = value.as_str().ok_or_else(|| {
                format!(
                    "invalid value type {:?}, string type required",
                    value.type_str()
                )
            })?;
            manager.register(name.clone(), TokenBucket::from_str(spec)?);
        }
        Ok(manager)
    }
}
/// Tracks, per key, the instant until which that key is throttled.
#[derive(DeriveMallocSizeOf)]
pub struct ThrottledManager<K: Eq + Hash + MallocSizeOf> {
    /// Throttle deadline for each currently tracked key.
    items: HashMap<K, Instant>,
}

// Written by hand because `#[derive(Default)]` would add a `K: Default`
// bound, which an empty map does not need. Every key type that worked
// before still works.
impl<K: Eq + Hash + MallocSizeOf> Default for ThrottledManager<K> {
    fn default() -> Self {
        ThrottledManager { items: HashMap::new() }
    }
}
impl<K: Eq + Hash + MallocSizeOf> ThrottledManager<K> {
    /// Marks `k` as throttled until `until`.
    ///
    /// If `k` already has a deadline, the later of the two is kept.
    pub fn set_throttled(&mut self, k: K, until: Instant) {
        self.items
            .entry(k)
            .and_modify(|deadline| {
                if *deadline < until {
                    *deadline = until;
                }
            })
            .or_insert(until);
    }

    /// Returns `true` while `k` is still inside its throttle window.
    ///
    /// An expired entry is removed the first time it is checked after its
    /// deadline.
    pub fn check_throttled(&mut self, k: &K) -> bool {
        let still_throttled = match self.items.get(k) {
            None => return false,
            Some(deadline) => Instant::now() < *deadline,
        };
        if !still_throttled {
            self.items.remove(k);
        }
        still_throttled
    }
}
#[cfg(test)]
mod tests {
    use crate::token_bucket::{ThrottleResult, TokenBucket};
    use std::{thread::sleep, time::Duration};
    // These tests read the wall clock and call `sleep`, so wait times are
    // checked against upper bounds rather than exact values.
    #[test]
    fn test_init_tokens() {
        // Empty pools recharging 1 token/s need at most 1s for one token.
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_secs(1)
        );
        // Each pool starts with 1 token. A cost of 2 is short by one token
        // (wait <= 1s) and must not debit anything. The following cost of 1
        // then succeeds immediately.
        let mut bucket = TokenBucket::new(3, 1, 1, 1, 3, 1, 1, 1);
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(1)
        );
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }
    #[test]
    fn test_acquire() {
        // Full pools of 3 are drained by costs of 1 and then 2.
        let mut bucket = TokenBucket::full(3, 1, 1, 3, 1, 1);
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
        assert_eq!(bucket.try_acquire_cost(2, 2), Ok(()));
        // With 0 tokens at 1 token/s, a deficit of N waits at most N seconds.
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_secs(1)
        );
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(2)
        );
        // Half a second in, no whole second has been credited yet, so the
        // wait is the remaining part of the first second.
        sleep(Duration::from_millis(500));
        assert!(
            bucket.try_acquire_cost(1, 1).unwrap_err()
                <= Duration::from_millis(500)
        );
        // About one second in, one token has been credited. That is still
        // short for a cost of 2, but enough for a cost of 1. This assumes
        // the test runs well under two seconds in total.
        sleep(Duration::from_millis(500));
        assert!(
            bucket.try_acquire_cost(2, 2).unwrap_err()
                <= Duration::from_secs(1)
        );
        assert_eq!(bucket.try_acquire_cost(1, 1), Ok(()));
    }
    // Asserts that `result` is `Throttled` with a wait of at most `wait_time`.
    fn assert_throttled(result: ThrottleResult, wait_time: Duration) {
        match result {
            ThrottleResult::Throttled(d) => assert!(d <= wait_time),
            _ => panic!("invalid throttle result"),
        }
    }
    #[test]
    fn test_throttled() {
        // The first failure opens a throttle window. With the default
        // max_throttled_counter of 0, the next call inside the window
        // returns `AlreadyThrottled`.
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
        // After the window has passed, a token has been credited, so the
        // call succeeds and the window state is cleared.
        sleep(Duration::from_secs(1));
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::Success);
        assert_eq!(bucket.throttled_until, None);
        assert_eq!(bucket.throttled_counter, 0);
    }
    #[test]
    fn test_tolerate_throttling() {
        // With max_throttled_counter = 2, the opening failure plus two more
        // calls inside the window report `Throttled`. The fourth call gets
        // `AlreadyThrottled`.
        let mut bucket = TokenBucket::empty(3, 1, 1, 3, 1, 1);
        bucket.set_max_throttled_counter(2);
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        assert_throttled(bucket.throttle(1, 1), Duration::from_secs(1));
        assert_eq!(bucket.throttle(1, 1), ThrottleResult::AlreadyThrottled);
    }
}