1#![deny(missing_docs)]
46#![allow(clippy::all)]
47
48mod atomic;
49mod iter;
50mod ops;
51mod util;
52
53pub use atomic::AtomicBitSet;
54pub use iter::{BitIter, DrainBitIter};
55#[cfg(feature = "parallel")]
56pub use iter::{BitParIter, BitProducer};
57pub use ops::{BitSetAll, BitSetAnd, BitSetNot, BitSetOr, BitSetXor};
58
59use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
60use util::*;
61
/// A hierarchical bit set.
///
/// `layer0` holds the actual membership bits. Each bit of `layer1`,
/// `layer2` and `layer3` records whether the corresponding word one layer
/// below is non-zero (see `add_slow`/`remove`), which lets iteration skip
/// large empty regions.
#[derive(Clone, Debug, Default)]
pub struct BitSet {
    // Count of set bits; maintained by `add`/`remove`, reported by `len`.
    num: usize,
    // Single top-level summary word.
    layer3: usize,
    // Summary words for `layer1`.
    layer2: Vec<usize>,
    // Summary words for `layer0`.
    layer1: Vec<usize>,
    // The actual membership bits.
    layer0: Vec<usize>,
}
75
impl MallocSizeOf for BitSet {
    // Reports the heap memory owned by the three layer vectors. `num` and
    // `layer3` are inline `usize` fields with no heap allocation, so they
    // contribute nothing and are intentionally omitted.
    fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
        self.layer2.size_of(ops)
            + self.layer1.size_of(ops)
            + self.layer0.size_of(ops)
    }
}
83
84impl BitSet {
85 pub fn new() -> BitSet { Default::default() }
87
88 #[inline(always)]
89 fn valid_range(max: Index) {
90 if (MAX_EID as u32) < max {
91 panic!("Expected index to be less then {}, found {}", MAX_EID, max);
92 }
93 }
94
95 pub fn with_capacity(max: Index) -> BitSet {
97 Self::valid_range(max);
98 let mut value = BitSet::new();
99 value.extend(max);
100 value
101 }
102
103 #[inline(never)]
104 fn extend(&mut self, id: Index) {
105 Self::valid_range(id);
106 let (p0, p1, p2) = offsets(id);
107
108 Self::fill_up(&mut self.layer2, p2);
109 Self::fill_up(&mut self.layer1, p1);
110 Self::fill_up(&mut self.layer0, p0);
111 }
112
113 fn fill_up(vec: &mut Vec<usize>, upper_index: usize) {
114 if vec.len() <= upper_index {
115 vec.resize(upper_index + 1, 0);
116 }
117 }
118
    /// Sets the summary bits for `id` in layers 1 through 3.
    ///
    /// `add` calls this only when the layer-0 word for `id` was previously
    /// zero, i.e. when the upper-layer summary bits may not be set yet.
    /// The layer-0 bit itself must already have been set by the caller.
    #[inline(never)]
    fn add_slow(&mut self, id: Index) {
        let (_, p1, p2) = offsets(id);
        self.layer1[p1] |= id.mask(SHIFT1);
        self.layer2[p2] |= id.mask(SHIFT2);
        self.layer3 |= id.mask(SHIFT3);
    }
128
129 #[inline(always)]
132 pub fn add(&mut self, id: Index) -> bool {
133 let (p0, mask) = (id.offset(SHIFT1), id.mask(SHIFT0));
134
135 if p0 >= self.layer0.len() {
136 self.extend(id);
137 }
138
139 if self.layer0[p0] & mask != 0 {
140 return true;
141 }
142
143 let old = self.layer0[p0];
146 self.layer0[p0] |= mask;
147 if old == 0 {
148 self.add_slow(id);
149 } else {
150 self.layer0[p0] |= mask;
151 }
152 self.num += 1;
153 false
154 }
155
156 #[allow(unused)]
157 fn layer_mut(&mut self, level: usize, idx: usize) -> &mut usize {
158 match level {
159 0 => {
160 Self::fill_up(&mut self.layer0, idx);
161 &mut self.layer0[idx]
162 }
163 1 => {
164 Self::fill_up(&mut self.layer1, idx);
165 &mut self.layer1[idx]
166 }
167 2 => {
168 Self::fill_up(&mut self.layer2, idx);
169 &mut self.layer2[idx]
170 }
171 3 => &mut self.layer3,
172 _ => panic!("Invalid layer: {}", level),
173 }
174 }
175
176 #[inline(always)]
180 pub fn remove(&mut self, id: Index) -> bool {
181 let (p0, p1, p2) = offsets(id);
182
183 if p0 >= self.layer0.len() {
184 return false;
185 }
186
187 if self.layer0[p0] & id.mask(SHIFT0) == 0 {
188 return false;
189 }
190
191 self.layer0[p0] &= !id.mask(SHIFT0);
196 if self.layer0[p0] != 0 {
197 self.num -= 1;
198 return true;
199 }
200
201 self.layer1[p1] &= !id.mask(SHIFT1);
202 if self.layer1[p1] != 0 {
203 self.num -= 1;
204 return true;
205 }
206
207 self.layer2[p2] &= !id.mask(SHIFT2);
208 if self.layer2[p2] != 0 {
209 self.num -= 1;
210 return true;
211 }
212
213 self.layer3 &= !id.mask(SHIFT3);
214 self.num -= 1;
215 return true;
216 }
217
218 #[inline(always)]
220 pub fn contains(&self, id: Index) -> bool {
221 let p0 = id.offset(SHIFT1);
222 p0 < self.layer0.len() && (self.layer0[p0] & id.mask(SHIFT0)) != 0
223 }
224
225 #[inline(always)]
227 pub fn contains_set(&self, other: &BitSet) -> bool {
228 for id in other.iter() {
229 if !self.contains(id) {
230 return false;
231 }
232 }
233 true
234 }
235
236 pub fn clear(&mut self) {
238 self.layer0.clear();
239 self.layer1.clear();
240 self.layer2.clear();
241 self.layer3 = 0;
242 self.num = 0;
243 }
244
    /// Returns the number of bits currently set in this set.
    #[inline(always)]
    pub fn len(&self) -> usize { self.num }
248}
249
/// A generic read interface over the hierarchical bit set layout: four
/// layers of `usize` words, where each bit of layer `n + 1` summarises
/// whether the corresponding word of layer `n` is non-zero.
pub trait BitSetLike {
    /// Returns the word at `idx` of the given `layer` (0-3).
    ///
    /// Panics if `layer` is greater than 3.
    fn get_from_layer(&self, layer: usize, idx: usize) -> usize {
        match layer {
            0 => self.layer0(idx),
            1 => self.layer1(idx),
            2 => self.layer2(idx),
            3 => self.layer3(),
            _ => panic!("Invalid layer: {}", layer),
        }
    }

    /// Returns `true` if no bit is set (the top summary word is zero).
    fn is_empty(&self) -> bool { self.layer3() == 0 }

    /// Returns the single top-level (layer 3) summary word.
    fn layer3(&self) -> usize;

    /// Returns the layer-2 summary word at index `i`.
    fn layer2(&self, i: usize) -> usize;

    /// Returns the layer-1 summary word at index `i`.
    fn layer1(&self, i: usize) -> usize;

    /// Returns the layer-0 word at index `i`; these are the actual bits.
    fn layer0(&self, i: usize) -> usize;

    /// Returns `true` if bit `i` is set.
    fn contains(&self, i: Index) -> bool;

    /// Consumes `self` and returns an iterator over all set bits.
    fn iter(self) -> BitIter<Self>
    where Self: Sized {
        let layer3 = self.layer3();

        BitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }

    /// Consumes `self` and returns a parallel iterator over all set bits.
    #[cfg(feature = "parallel")]
    fn par_iter(self) -> BitParIter<Self>
    where Self: Sized {
        BitParIter::new(self)
    }
}
314
/// A `BitSetLike` whose bits can also be removed, enabling draining
/// iteration.
pub trait DrainableBitSet: BitSetLike {
    /// Removes bit `i`, returning `true` if it was previously set.
    fn remove(&mut self, i: Index) -> bool;

    /// Returns an iterator that yields every set bit while removing it
    /// from the set.
    fn drain<'a>(&'a mut self) -> DrainBitIter<'a, Self>
    where Self: Sized {
        let layer3 = self.layer3();

        DrainBitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }
}
331
// Forward `BitSetLike` through shared references, so `&BitSet` (and any
// `&T where T: BitSetLike`) can be iterated without consuming the set.
impl<'a, T> BitSetLike for &'a T
where T: BitSetLike + ?Sized
{
    #[inline(always)]
    fn layer3(&self) -> usize { (*self).layer3() }

    #[inline(always)]
    fn layer2(&self, i: usize) -> usize { (*self).layer2(i) }

    #[inline(always)]
    fn layer1(&self, i: usize) -> usize { (*self).layer1(i) }

    #[inline(always)]
    fn layer0(&self, i: usize) -> usize { (*self).layer0(i) }

    #[inline(always)]
    fn contains(&self, i: Index) -> bool { (*self).contains(i) }
}
350
// Forward `BitSetLike` through mutable references as well; needed so that
// `&mut T` can also satisfy `DrainableBitSet` below.
impl<'a, T> BitSetLike for &'a mut T
where T: BitSetLike + ?Sized
{
    #[inline(always)]
    fn layer3(&self) -> usize { (**self).layer3() }

    #[inline(always)]
    fn layer2(&self, i: usize) -> usize { (**self).layer2(i) }

    #[inline(always)]
    fn layer1(&self, i: usize) -> usize { (**self).layer1(i) }

    #[inline(always)]
    fn layer0(&self, i: usize) -> usize { (**self).layer0(i) }

    #[inline(always)]
    fn contains(&self, i: Index) -> bool { (**self).contains(i) }
}
369
// Forward `DrainableBitSet` through mutable references (removal needs
// mutable access, so there is no `&T` counterpart).
impl<'a, T> DrainableBitSet for &'a mut T
where T: DrainableBitSet
{
    #[inline(always)]
    fn remove(&mut self, i: Index) -> bool { (**self).remove(i) }
}
376
377impl BitSetLike for BitSet {
378 #[inline(always)]
379 fn layer3(&self) -> usize { self.layer3 }
380
381 #[inline(always)]
382 fn layer2(&self, i: usize) -> usize {
383 self.layer2.get(i).map(|&x| x).unwrap_or(0)
384 }
385
386 #[inline(always)]
387 fn layer1(&self, i: usize) -> usize {
388 self.layer1.get(i).map(|&x| x).unwrap_or(0)
389 }
390
391 #[inline(always)]
392 fn layer0(&self, i: usize) -> usize {
393 self.layer0.get(i).map(|&x| x).unwrap_or(0)
394 }
395
396 #[inline(always)]
397 fn contains(&self, i: Index) -> bool { self.contains(i) }
398}
399
// `BitSet` supports removal, so it can be drained.
impl DrainableBitSet for BitSet {
    #[inline(always)]
    fn remove(&mut self, i: Index) -> bool { self.remove(i) }
}
404
405impl PartialEq for BitSet {
406 #[inline(always)]
407 fn eq(&self, rhv: &BitSet) -> bool {
408 if self.layer3 != rhv.layer3 {
409 return false;
410 }
411 if self.layer2.len() != rhv.layer2.len()
412 || self.layer1.len() != rhv.layer1.len()
413 || self.layer0.len() != rhv.layer0.len()
414 {
415 return false;
416 }
417
418 for i in 0..self.layer2.len() {
419 if self.layer2(i) != rhv.layer2(i) {
420 return false;
421 }
422 }
423 for i in 0..self.layer1.len() {
424 if self.layer1(i) != rhv.layer1(i) {
425 return false;
426 }
427 }
428 for i in 0..self.layer0.len() {
429 if self.layer0(i) != rhv.layer0(i) {
430 return false;
431 }
432 }
433
434 true
435 }
436}
// Word-wise integer comparison is a total equivalence relation (no
// NaN-like partial values), so the `Eq` marker is sound.
impl Eq for BitSet {}
438
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::{BitSet, BitSetAnd, BitSetLike, BitSetNot, BITS};

    // `add` reports prior membership: false on first insert, true after.
    #[test]
    fn insert() {
        let mut c = BitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
        }
    }

    // Same as `insert`, but large enough to exercise all summary layers.
    #[test]
    fn insert_100k() {
        let mut c = BitSet::new();
        for i in 0..100_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..100_000 {
            assert!(c.contains(i));
        }
    }
    // `remove` reports prior membership and clears the bit exactly once.
    #[test]
    fn remove() {
        let mut c = BitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
            assert!(c.remove(i));
            assert!(!c.contains(i));
            assert!(!c.remove(i));
        }
    }

    // Iteration yields every inserted index, in ascending order.
    #[test]
    fn iter() {
        let mut c = BitSet::new();
        for i in 0..100_000 {
            c.add(i);
        }

        let mut count = 0;
        for (idx, i) in c.iter().enumerate() {
            count += 1;
            assert_eq!(idx, i as usize);
        }
        assert_eq!(count, 100_000);
    }

    // Disjoint sets iterate independently; their AND is empty.
    #[test]
    fn iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }

        assert_eq!((&odd).iter().count(), 50_000);
        assert_eq!((&even).iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).iter().count(), 0);
    }

    // Randomly scattered bits are all visited exactly once. `add` returns
    // false for a fresh insert, so `added` counts distinct indices.
    #[test]
    fn iter_random_add() {
        use rand::prelude::*;

        let mut set = BitSet::new();
        let mut rng = rng();
        let limit = 1_048_576;
        let mut added = 0;
        for _ in 0..(limit / 10) {
            let index = rng.random_range(0..limit);
            if !set.add(index) {
                added += 1;
            }
        }
        assert_eq!(set.iter().count(), added as usize);
    }

    // Sparse clusters spread across the layer hierarchy: 8^3 bits placed by
    // shifting by BITS and 2*BITS so they land in different summary words.
    #[test]
    fn iter_clusters() {
        let mut set = BitSet::new();
        for x in 0..8 {
            let x = (x * 3) << (BITS * 2); for y in 0..8 {
                let y = (y * 3) << (BITS);
                for z in 0..8 {
                    let z = z * 2;
                    set.add(x + y + z);
                }
            }
        }
        assert_eq!(set.iter().count(), 8usize.pow(3));
    }

    // `BitSetNot` complements the set: iterating it yields the even
    // indices that were left out above.
    #[test]
    fn not() {
        let mut c = BitSet::new();
        for i in 0..10_000 {
            if i % 2 == 1 {
                c.add(i);
            }
        }
        let d = BitSetNot(c);
        for (idx, i) in d.iter().take(5_000).enumerate() {
            assert_eq!(idx * 2, i as usize);
        }
    }
}
564
#[cfg(all(test, feature = "parallel"))]
mod test_parallel {
    use super::{BitSet, BitSetAnd, BitSetLike, BITS};
    use rand::rng;
    use rayon::iter::ParallelIterator;

    // A single bit anywhere in the index range is found by the parallel
    // iterator, including the very last representable position.
    #[test]
    fn par_iter_one() {
        let step = 5000;
        let tests = 1_048_576 / step;
        for n in 0..tests {
            let n = n * step;
            let mut set = BitSet::new();
            set.add(n);
            assert_eq!(set.par_iter().count(), 1);
        }
        let mut set = BitSet::new();
        set.add(1_048_576 - 1);
        assert_eq!(set.par_iter().count(), 1);
    }

    // Parallel iteration visits exactly the randomly-inserted indices:
    // `check_set` ends empty (nothing missed) and `missing_set` stays empty
    // (nothing spurious). Mutexes serialize access from rayon workers.
    #[test]
    fn par_iter_random_add() {
        use rand::prelude::*;
        use std::{
            collections::HashSet,
            sync::{Arc, Mutex},
        };

        let mut set = BitSet::new();
        let mut check_set = HashSet::new();
        let mut rng = rng();
        let limit = 1_048_576;
        for _ in 0..(limit / 10) {
            let index = rng.random_range(0..limit);
            set.add(index);
            check_set.insert(index);
        }
        let check_set = Arc::new(Mutex::new(check_set));
        let missing_set = Arc::new(Mutex::new(HashSet::new()));
        set.par_iter().for_each(|n| {
            let check_set = check_set.clone();
            let missing_set = missing_set.clone();
            let mut check = check_set.lock().unwrap();
            if !check.remove(&n) {
                let mut missing = missing_set.lock().unwrap();
                missing.insert(n);
            }
        });
        let check_set = check_set.lock().unwrap();
        let missing_set = missing_set.lock().unwrap();
        if !check_set.is_empty() && !missing_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}
                There were values that got iterated, but that shouldn't be: {:?}",
                *check_set, *missing_set
            );
        }
        if !check_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}",
                *check_set
            );
        }
        if !missing_set.is_empty() {
            panic!(
                "There were values that got iterated, but that shouldn't be: {:?}",
                *missing_set
            );
        }
    }

    // Parallel counterpart of `tests::iter_odd_even`.
    #[test]
    fn par_iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }

        assert_eq!((&odd).par_iter().count(), 50_000);
        assert_eq!((&even).par_iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).par_iter().count(), 0);
    }

    // Parallel counterpart of `tests::iter_clusters`, additionally checking
    // set membership (not just the count) via the same mutex-guarded
    // check/missing bookkeeping as `par_iter_random_add`.
    #[test]
    fn par_iter_clusters() {
        use std::{
            collections::HashSet,
            sync::{Arc, Mutex},
        };
        let mut set = BitSet::new();
        let mut check_set = HashSet::new();
        for x in 0..8 {
            let x = (x * 3) << (BITS * 2); for y in 0..8 {
                let y = (y * 3) << (BITS);
                for z in 0..8 {
                    let z = z * 2;
                    let index = x + y + z;
                    set.add(index);
                    check_set.insert(index);
                }
            }
        }
        let check_set = Arc::new(Mutex::new(check_set));
        let missing_set = Arc::new(Mutex::new(HashSet::new()));
        set.par_iter().for_each(|n| {
            let check_set = check_set.clone();
            let missing_set = missing_set.clone();
            let mut check = check_set.lock().unwrap();
            if !check.remove(&n) {
                let mut missing = missing_set.lock().unwrap();
                missing.insert(n);
            }
        });
        let check_set = check_set.lock().unwrap();
        let missing_set = missing_set.lock().unwrap();
        if !check_set.is_empty() && !missing_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}
                There were values that got iterated, but that shouldn't be: {:?}",
                *check_set, *missing_set
            );
        }
        if !check_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}",
                *check_set
            );
        }
        if !missing_set.is_empty() {
            panic!(
                "There were values that got iterated, but that shouldn't be: {:?}",
                *missing_set
            );
        }
    }
}