network/ip/
sessions_limit.rs

1use crate::ip::util::SubnetType;
2use std::{
3    collections::HashMap, convert::TryFrom, hash::Hash, net::IpAddr,
4    str::FromStr,
5};
6
/// Limits how many sessions may be open at the same time for a single IP
/// address or for the subnet it belongs to.
///
/// Every method has a permissive default that returns `true`, so an
/// implementation that overrides nothing never restricts anything.
pub trait SessionIpLimit: Send + Sync {
    /// Reports whether sessions are currently being tracked for `ip`.
    fn contains(&self, _ip: &IpAddr) -> bool {
        true
    }

    /// Reports whether one more session from `ip` would currently be
    /// accepted.
    fn is_allowed(&self, _ip: &IpAddr) -> bool {
        true
    }

    /// Records a new session for `ip`; returns `false` if it is rejected.
    fn add(&mut self, _ip: IpAddr) -> bool {
        true
    }

    /// Releases one session for `ip`; returns `false` if nothing was tracked.
    fn remove(&mut self, _ip: &IpAddr) -> bool {
        true
    }
}
15
/// Per-IP and per-subnet session quotas consumed by `new_session_ip_limit`.
///
/// A quota of `0` disables the corresponding limit. The textual form accepted
/// by `TryFrom<String>` is four comma-separated numbers in field order,
/// `single_ip,subnet_a,subnet_b,subnet_c` (e.g. `"1,0,0,0"`); whitespace
/// around each number is ignored.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct SessionIpLimitConfig {
    /// Maximum sessions per exact IP address.
    single_ip_quota: usize,
    /// Maximum sessions per subnet A (ip/8).
    subnet_a_quota: usize,
    /// Maximum sessions per subnet B (ip/16).
    subnet_b_quota: usize,
    /// Maximum sessions per subnet C (ip/24).
    subnet_c_quota: usize,
}

impl TryFrom<String> for SessionIpLimitConfig {
    type Error = String;

    /// Parses `"single_ip,subnet_a,subnet_b,subnet_c"` into a config.
    ///
    /// # Errors
    ///
    /// Returns `"failed to parse number: …"` if any field is not a valid
    /// `usize` (this is checked first), or
    /// `"invalid number of fields, expected = 4, actual = N"` if the number of
    /// fields is not exactly four.
    fn try_from(value: String) -> Result<Self, String> {
        // Parse every field before checking the count so that a malformed
        // number is reported in preference to a wrong field count.
        let nums = value
            .split(',')
            .map(|field| {
                usize::from_str(field.trim())
                    .map_err(|e| format!("failed to parse number: {:?}", e))
            })
            .collect::<Result<Vec<usize>, String>>()?;

        match nums.as_slice() {
            &[single_ip_quota, subnet_a_quota, subnet_b_quota, subnet_c_quota] => {
                Ok(SessionIpLimitConfig {
                    single_ip_quota,
                    subnet_a_quota,
                    subnet_b_quota,
                    subnet_c_quota,
                })
            }
            _ => Err(format!(
                "invalid number of fields, expected = 4, actual = {}",
                nums.len()
            )),
        }
    }
}
52
53/// Creates a SessionIpLimit instance with specified IP quotas. The
54/// `subnet_quotas` represents subnet-a (ip/8), subnet-b (ip/16) and subnet-c
55/// (ip/24) respectively.
56pub fn new_session_ip_limit(
57    config: &SessionIpLimitConfig,
58) -> Box<dyn SessionIpLimit> {
59    let mut limits: Vec<Box<dyn SessionIpLimit>> = Vec::new();
60
61    if config.single_ip_quota > 0 {
62        limits.push(Box::new(SingleIpLimit::new(config.single_ip_quota)));
63    }
64
65    if config.subnet_a_quota > 0 {
66        limits.push(Box::new(SubnetLimit::new(
67            config.subnet_a_quota,
68            SubnetType::A,
69        )));
70    }
71
72    if config.subnet_b_quota > 0 {
73        limits.push(Box::new(SubnetLimit::new(
74            config.subnet_b_quota,
75            SubnetType::B,
76        )));
77    }
78
79    if config.subnet_c_quota > 0 {
80        limits.push(Box::new(SubnetLimit::new(
81            config.subnet_c_quota,
82            SubnetType::C,
83        )));
84    }
85
86    if limits.is_empty() {
87        Box::new(NoopLimit)
88    } else {
89        Box::new(CompositeLimit::new(limits))
90    }
91}
92
93struct NoopLimit;
94impl SessionIpLimit for NoopLimit {}
95
/// Counts live sessions per key and refuses a new one once that key has
/// reached `quota`.
struct GenericLimit<T> {
    /// Maximum number of sessions per key; `new` guarantees it is non-zero.
    quota: usize,
    /// Live session count per key. A key is only present while its count is
    /// at least 1; it is dropped as soon as its last session is removed.
    items: HashMap<T, usize>,
}

impl<T: Hash + Eq> GenericLimit<T> {
    /// Creates an empty counter allowing up to `quota` sessions per key.
    ///
    /// # Panics
    ///
    /// Panics if `quota` is 0.
    fn new(quota: usize) -> Self {
        assert!(quota > 0);
        Self {
            quota,
            items: HashMap::new(),
        }
    }

    /// Returns `true` if at least one session is tracked for `key`.
    fn contains(&self, key: &T) -> bool {
        self.items.contains_key(key)
    }

    /// Returns `true` if one more session for `key` would fit in the quota.
    fn is_allowed(&self, key: &T) -> bool {
        self.items
            .get(key)
            .map_or(true, |&count| count < self.quota)
    }

    /// Records one session for `key`; returns `false` (and changes nothing)
    /// if the key is already at its quota.
    fn add(&mut self, key: T) -> bool {
        // A new key starts at 0 and is bumped to 1 immediately below: quota
        // is non-zero, so the increment always happens for a fresh key and
        // no zero-count entry is ever left behind.
        let count = self.items.entry(key).or_insert(0);
        if *count >= self.quota {
            return false;
        }
        *count += 1;
        true
    }

    /// Releases one session for `key`; returns `false` if none was tracked.
    fn remove(&mut self, key: &T) -> bool {
        // Stored counts are always >= 1, so the decrement cannot underflow.
        let drained = match self.items.get_mut(key) {
            None => return false,
            Some(count) => {
                *count -= 1;
                *count == 0
            }
        };
        if drained {
            self.items.remove(key);
        }
        true
    }
}
152
153struct SingleIpLimit {
154    inner: GenericLimit<IpAddr>,
155}
156
157impl SingleIpLimit {
158    fn new(quota: usize) -> Self {
159        SingleIpLimit {
160            inner: GenericLimit::new(quota),
161        }
162    }
163}
164
165impl SessionIpLimit for SingleIpLimit {
166    fn contains(&self, ip: &IpAddr) -> bool { self.inner.contains(ip) }
167
168    fn is_allowed(&self, ip: &IpAddr) -> bool { self.inner.is_allowed(ip) }
169
170    fn add(&mut self, ip: IpAddr) -> bool { self.inner.add(ip) }
171
172    fn remove(&mut self, ip: &IpAddr) -> bool { self.inner.remove(ip) }
173}
174
175struct SubnetLimit {
176    inner: GenericLimit<u32>,
177    subnet_type: SubnetType,
178}
179
180impl SubnetLimit {
181    fn new(quota: usize, subnet_type: SubnetType) -> Self {
182        SubnetLimit {
183            inner: GenericLimit::new(quota),
184            subnet_type,
185        }
186    }
187}
188
189impl SessionIpLimit for SubnetLimit {
190    fn contains(&self, ip: &IpAddr) -> bool {
191        let subnet = self.subnet_type.subnet(ip);
192        self.inner.contains(&subnet)
193    }
194
195    fn is_allowed(&self, ip: &IpAddr) -> bool {
196        let subnet = self.subnet_type.subnet(ip);
197        self.inner.is_allowed(&subnet)
198    }
199
200    fn add(&mut self, ip: IpAddr) -> bool {
201        let subnet = self.subnet_type.subnet(&ip);
202        self.inner.add(subnet)
203    }
204
205    fn remove(&mut self, ip: &IpAddr) -> bool {
206        let subnet = self.subnet_type.subnet(ip);
207        self.inner.remove(&subnet)
208    }
209}
210
211struct CompositeLimit {
212    limits: Vec<Box<dyn SessionIpLimit>>,
213}
214
215impl CompositeLimit {
216    fn new(limits: Vec<Box<dyn SessionIpLimit>>) -> Self {
217        CompositeLimit { limits }
218    }
219}
220
221impl SessionIpLimit for CompositeLimit {
222    fn is_allowed(&self, ip: &IpAddr) -> bool {
223        self.limits.iter().all(|l| l.is_allowed(ip))
224    }
225
226    fn add(&mut self, ip: IpAddr) -> bool {
227        if !self.is_allowed(&ip) {
228            return false;
229        }
230
231        for limit in self.limits.iter_mut() {
232            assert!(limit.add(ip));
233        }
234
235        true
236    }
237
238    fn remove(&mut self, ip: &IpAddr) -> bool {
239        if self.limits.iter().any(|l| !l.contains(ip)) {
240            return false;
241        }
242
243        for limit in self.limits.iter_mut() {
244            assert!(limit.remove(ip));
245        }
246
247        true
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::{new_session_ip_limit, SessionIpLimit, SessionIpLimitConfig};
    use std::{convert::TryInto, net::IpAddr, str::FromStr};

    fn new_ip(ip: &'static str) -> IpAddr { IpAddr::from_str(ip).unwrap() }

    fn new_limit(config: &str) -> Box<dyn SessionIpLimit> {
        let config: String = config.into();
        new_session_ip_limit(&config.try_into().unwrap())
    }

    fn parse_config(config: &str) -> Result<SessionIpLimitConfig, String> {
        String::from(config).try_into()
    }

    #[test]
    fn test_config_parse() {
        let config = parse_config("1,2,3,4").unwrap();
        assert_eq!(config.single_ip_quota, 1);
        assert_eq!(config.subnet_a_quota, 2);
        assert_eq!(config.subnet_b_quota, 3);
        assert_eq!(config.subnet_c_quota, 4);

        assert_eq!(
            parse_config("1,2,3").unwrap_err(),
            "invalid number of fields, expected = 4, actual = 3"
        );
        assert!(parse_config("1,x,3,4")
            .unwrap_err()
            .starts_with("failed to parse number"));
        assert!(parse_config("").is_err());
    }

    #[test]
    fn test_noop() {
        let mut limit = new_limit("0,0,0,0");
        assert!(limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(limit.remove(&new_ip("127.0.0.2")));
    }

    #[test]
    fn test_single_ip() {
        let mut limit = new_limit("1,0,0,0");

        assert!(!limit.remove(&new_ip("127.0.0.1")));

        assert!(limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));

        assert!(!limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(!limit.add(new_ip("127.0.0.1")));

        assert!(limit.is_allowed(&new_ip("127.0.0.2")));
        assert!(limit.add(new_ip("127.0.0.2")));
    }

    #[test]
    fn test_single_ip_remove_releases_quota() {
        let mut limit = new_limit("1,0,0,0");

        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(!limit.add(new_ip("127.0.0.1")));
        assert!(limit.remove(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));
    }

    #[test]
    fn test_subnet_all() {
        let mut limit = new_limit("0,3,2,1");

        assert!(limit.add(new_ip("127.0.0.1")));

        // subnet c
        assert!(!limit.add(new_ip("127.0.0.2")));

        // subnet b
        assert!(limit.add(new_ip("127.0.1.1")));
        assert!(!limit.add(new_ip("127.0.1.1")));
        assert!(!limit.add(new_ip("127.0.1.2")));
        assert!(!limit.add(new_ip("127.0.2.1")));

        // subnet a
        assert!(limit.add(new_ip("192.168.0.1")));
        assert!(limit.add(new_ip("192.169.0.1")));
        assert!(limit.add(new_ip("192.170.0.1")));
        assert!(!limit.add(new_ip("192.171.0.1")));
    }

    #[test]
    fn test_subnet_b() {
        let mut limit = new_limit("0,0,2,0");

        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.2")));
        assert!(!limit.add(new_ip("127.0.0.3")));
        assert!(!limit.add(new_ip("127.0.1.1")));
        assert!(limit.add(new_ip("127.1.0.1")));
        assert!(limit.add(new_ip("127.2.0.1")));
        assert!(limit.add(new_ip("127.3.0.1")));
    }

    #[test]
    fn test_subnet_c_remove_releases_quota() {
        let mut limit = new_limit("0,0,0,1");

        assert!(limit.add(new_ip("10.0.0.1")));
        // Same /24 subnet, quota already used.
        assert!(!limit.add(new_ip("10.0.0.2")));
        assert!(limit.remove(&new_ip("10.0.0.1")));
        assert!(limit.add(new_ip("10.0.0.2")));
    }
}