1use crate::ip::util::SubnetType;
2use std::{
3 collections::HashMap, convert::TryFrom, hash::Hash, net::IpAddr,
4 str::FromStr,
5};
6
/// Admission control for sessions, keyed by the remote IP address.
///
/// Every method has a permissive default that returns `true`, so a type that
/// overrides nothing admits and "tracks" every address. Concrete limits should
/// override all four methods.
pub trait SessionIpLimit: Send + Sync {
    /// Reports whether `_addr` currently has sessions recorded by this limit.
    fn contains(&self, _addr: &IpAddr) -> bool {
        true
    }

    /// Reports whether one more session from `_addr` would be accepted,
    /// without recording anything.
    fn is_allowed(&self, _addr: &IpAddr) -> bool {
        true
    }

    /// Records a session from `_addr`; returns `false` when it is rejected.
    fn add(&mut self, _addr: IpAddr) -> bool {
        true
    }

    /// Releases a session previously recorded for `_addr`; returns `false`
    /// when nothing was recorded for it.
    fn remove(&mut self, _addr: &IpAddr) -> bool {
        true
    }
}
15
/// Per-scope session quotas, normally parsed from a `"single,a,b,c"` string.
///
/// A quota of `0` disables the corresponding limit entirely.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct SessionIpLimitConfig {
    /// Maximum number of recorded sessions per exact IP address.
    single_ip_quota: usize,
    /// Maximum number of recorded sessions per `SubnetType::A` subnet.
    subnet_a_quota: usize,
    /// Maximum number of recorded sessions per `SubnetType::B` subnet.
    subnet_b_quota: usize,
    /// Maximum number of recorded sessions per `SubnetType::C` subnet.
    subnet_c_quota: usize,
}
23
24impl TryFrom<String> for SessionIpLimitConfig {
25 type Error = String;
26
27 fn try_from(value: String) -> Result<Self, String> {
28 let configs: Vec<&str> = value.split(',').collect();
29
30 let mut nums = Vec::new();
31 for s in configs {
32 let num = usize::from_str(s)
33 .map_err(|e| format!("failed to parse number: {:?}", e))?;
34 nums.push(num);
35 }
36
37 if nums.len() != 4 {
38 return Err(format!(
39 "invalid number of fields, expected = 4, actual = {}",
40 nums.len()
41 ));
42 }
43
44 Ok(SessionIpLimitConfig {
45 single_ip_quota: nums[0],
46 subnet_a_quota: nums[1],
47 subnet_b_quota: nums[2],
48 subnet_c_quota: nums[3],
49 })
50 }
51}
52
53pub fn new_session_ip_limit(
57 config: &SessionIpLimitConfig,
58) -> Box<dyn SessionIpLimit> {
59 let mut limits: Vec<Box<dyn SessionIpLimit>> = Vec::new();
60
61 if config.single_ip_quota > 0 {
62 limits.push(Box::new(SingleIpLimit::new(config.single_ip_quota)));
63 }
64
65 if config.subnet_a_quota > 0 {
66 limits.push(Box::new(SubnetLimit::new(
67 config.subnet_a_quota,
68 SubnetType::A,
69 )));
70 }
71
72 if config.subnet_b_quota > 0 {
73 limits.push(Box::new(SubnetLimit::new(
74 config.subnet_b_quota,
75 SubnetType::B,
76 )));
77 }
78
79 if config.subnet_c_quota > 0 {
80 limits.push(Box::new(SubnetLimit::new(
81 config.subnet_c_quota,
82 SubnetType::C,
83 )));
84 }
85
86 if limits.is_empty() {
87 Box::new(NoopLimit)
88 } else {
89 Box::new(CompositeLimit::new(limits))
90 }
91}
92
93struct NoopLimit;
94impl SessionIpLimit for NoopLimit {}
95
/// Per-key counter table with a shared quota.
///
/// It tracks how many sessions are currently recorded for each key and
/// refuses to record more than `quota` for any single key.
struct GenericLimit<T> {
    /// Maximum count allowed per key; `new` guarantees it is non-zero.
    quota: usize,
    /// Current count per key. An entry is removed when its count would drop
    /// to zero, so every stored count lies in `1..=quota`.
    items: HashMap<T, usize>,
}

impl<T: Hash + Eq> GenericLimit<T> {
    /// Creates an empty table.
    ///
    /// # Panics
    ///
    /// Panics if `quota` is zero.
    fn new(quota: usize) -> Self {
        assert!(quota > 0);

        GenericLimit {
            quota,
            items: HashMap::new(),
        }
    }

    /// Whether `key` currently has at least one session recorded.
    fn contains(&self, key: &T) -> bool { self.items.contains_key(key) }

    /// Whether recording one more session for `key` would stay within the
    /// quota. Nothing is modified.
    fn is_allowed(&self, key: &T) -> bool {
        self.items.get(key).map_or(true, |&count| count < self.quota)
    }

    /// Records one session for `key`.
    ///
    /// Returns `false` and leaves the table unchanged if the key is already
    /// at its quota.
    fn add(&mut self, key: T) -> bool {
        let quota = self.quota;
        // Single hash lookup via the entry API. A fresh entry starts at 0 and
        // is bumped to 1 immediately; that always fits because `quota > 0`.
        let count = self.items.entry(key).or_insert(0);
        if *count < quota {
            *count += 1;
            true
        } else {
            false
        }
    }

    /// Releases one session for `key`, dropping the entry when its count
    /// reaches zero.
    ///
    /// Returns `false` if `key` was not tracked.
    fn remove(&mut self, key: &T) -> bool {
        let count = match self.items.get_mut(key) {
            Some(count) => count,
            None => return false,
        };

        if *count > 1 {
            *count -= 1;
        } else {
            // Last session for this key: drop the entry instead of storing 0.
            self.items.remove(key);
        }

        true
    }
}
152
153struct SingleIpLimit {
154 inner: GenericLimit<IpAddr>,
155}
156
157impl SingleIpLimit {
158 fn new(quota: usize) -> Self {
159 SingleIpLimit {
160 inner: GenericLimit::new(quota),
161 }
162 }
163}
164
165impl SessionIpLimit for SingleIpLimit {
166 fn contains(&self, ip: &IpAddr) -> bool { self.inner.contains(ip) }
167
168 fn is_allowed(&self, ip: &IpAddr) -> bool { self.inner.is_allowed(ip) }
169
170 fn add(&mut self, ip: IpAddr) -> bool { self.inner.add(ip) }
171
172 fn remove(&mut self, ip: &IpAddr) -> bool { self.inner.remove(ip) }
173}
174
175struct SubnetLimit {
176 inner: GenericLimit<u32>,
177 subnet_type: SubnetType,
178}
179
180impl SubnetLimit {
181 fn new(quota: usize, subnet_type: SubnetType) -> Self {
182 SubnetLimit {
183 inner: GenericLimit::new(quota),
184 subnet_type,
185 }
186 }
187}
188
189impl SessionIpLimit for SubnetLimit {
190 fn contains(&self, ip: &IpAddr) -> bool {
191 let subnet = self.subnet_type.subnet(ip);
192 self.inner.contains(&subnet)
193 }
194
195 fn is_allowed(&self, ip: &IpAddr) -> bool {
196 let subnet = self.subnet_type.subnet(ip);
197 self.inner.is_allowed(&subnet)
198 }
199
200 fn add(&mut self, ip: IpAddr) -> bool {
201 let subnet = self.subnet_type.subnet(&ip);
202 self.inner.add(subnet)
203 }
204
205 fn remove(&mut self, ip: &IpAddr) -> bool {
206 let subnet = self.subnet_type.subnet(ip);
207 self.inner.remove(&subnet)
208 }
209}
210
211struct CompositeLimit {
212 limits: Vec<Box<dyn SessionIpLimit>>,
213}
214
215impl CompositeLimit {
216 fn new(limits: Vec<Box<dyn SessionIpLimit>>) -> Self {
217 CompositeLimit { limits }
218 }
219}
220
221impl SessionIpLimit for CompositeLimit {
222 fn is_allowed(&self, ip: &IpAddr) -> bool {
223 self.limits.iter().all(|l| l.is_allowed(ip))
224 }
225
226 fn add(&mut self, ip: IpAddr) -> bool {
227 if !self.is_allowed(&ip) {
228 return false;
229 }
230
231 for limit in self.limits.iter_mut() {
232 assert!(limit.add(ip));
233 }
234
235 true
236 }
237
238 fn remove(&mut self, ip: &IpAddr) -> bool {
239 if self.limits.iter().any(|l| !l.contains(ip)) {
240 return false;
241 }
242
243 for limit in self.limits.iter_mut() {
244 assert!(limit.remove(ip));
245 }
246
247 true
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::{new_session_ip_limit, SessionIpLimit, SessionIpLimitConfig};
    use std::{convert::TryInto, net::IpAddr, str::FromStr};

    fn new_ip(ip: &'static str) -> IpAddr { IpAddr::from_str(ip).unwrap() }

    /// Parses a `"single,a,b,c"` quota string through `TryFrom<String>`.
    fn parse_config(config: &str) -> Result<SessionIpLimitConfig, String> {
        let config: String = config.into();
        config.try_into()
    }

    fn new_limit(config: &str) -> Box<dyn SessionIpLimit> {
        new_session_ip_limit(&parse_config(config).unwrap())
    }

    #[test]
    fn test_config_rejects_malformed_input() {
        assert_eq!(
            parse_config("1,2,3").unwrap_err(),
            "invalid number of fields, expected = 4, actual = 3"
        );
        assert!(parse_config("1,2,3,4,5").is_err());
        assert!(parse_config("1,x,3,4").is_err());
        assert!(parse_config("1,-2,3,4").is_err());
        assert!(parse_config("").is_err());
    }

    #[test]
    fn test_noop() {
        // All quotas are zero: every call succeeds.
        let mut limit = new_limit("0,0,0,0");
        assert!(limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(limit.contains(&new_ip("127.0.0.1")));
        assert!(limit.remove(&new_ip("127.0.0.2")));
    }

    #[test]
    fn test_single_ip() {
        let mut limit = new_limit("1,0,0,0");

        // Nothing has been recorded yet.
        assert!(!limit.remove(&new_ip("127.0.0.1")));

        assert!(limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));

        // The quota of one session per address is used up.
        assert!(!limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(!limit.add(new_ip("127.0.0.1")));

        // Other addresses are counted independently.
        assert!(limit.is_allowed(&new_ip("127.0.0.2")));
        assert!(limit.add(new_ip("127.0.0.2")));

        // Releasing the session frees the slot again.
        assert!(limit.remove(&new_ip("127.0.0.1")));
        assert!(limit.is_allowed(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.1")));
    }

    #[test]
    fn test_subnet_all() {
        // Quotas: subnet A = 3, subnet B = 2, subnet C = 1.
        let mut limit = new_limit("0,3,2,1");

        assert!(limit.add(new_ip("127.0.0.1")));

        // Shares its subnet C with 127.0.0.1, and that quota is 1.
        assert!(!limit.add(new_ip("127.0.0.2")));

        // A new subnet C, but only the second entry in its subnet B.
        assert!(limit.add(new_ip("127.0.1.1")));
        // Its subnet B (and subnet C) are now full.
        assert!(!limit.add(new_ip("127.0.1.1")));
        assert!(!limit.add(new_ip("127.0.1.2")));
        assert!(!limit.add(new_ip("127.0.2.1")));

        // Distinct B/C subnets under one subnet A: only the A quota (3) binds.
        assert!(limit.add(new_ip("192.168.0.1")));
        assert!(limit.add(new_ip("192.169.0.1")));
        assert!(limit.add(new_ip("192.170.0.1")));
        assert!(!limit.add(new_ip("192.171.0.1")));
    }

    #[test]
    fn test_subnet_b() {
        let mut limit = new_limit("0,0,2,0");

        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.2")));
        assert!(!limit.add(new_ip("127.0.0.3")));
        assert!(!limit.add(new_ip("127.0.1.1")));
        assert!(limit.add(new_ip("127.1.0.1")));
        assert!(limit.add(new_ip("127.2.0.1")));
        assert!(limit.add(new_ip("127.3.0.1")));
    }

    #[test]
    fn test_subnet_release() {
        let mut limit = new_limit("0,0,2,0");

        assert!(limit.add(new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.2")));
        assert!(!limit.add(new_ip("127.0.0.3")));

        // Freeing one session in the subnet admits another address from it.
        assert!(limit.remove(&new_ip("127.0.0.1")));
        assert!(limit.add(new_ip("127.0.0.3")));
        assert!(!limit.add(new_ip("127.0.5.5")));
    }
}