1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use std::{
    collections::{HashMap, HashSet},
    hash::Hash,
};

use super::CheckpointEntry;

/// A checkpoint layer that exposes its contents as a `HashMap` from
/// `Key` to `CheckpointEntry<Value>`.
///
/// `LazyDiscardedVec` only interacts with its layers through these two
/// accessors.
pub trait CheckpointLayerTrait {
    /// Key type of the layer's map; must be hashable, comparable and
    /// clonable.
    type Key: Eq + Hash + Clone;
    /// Payload type wrapped by each `CheckpointEntry` stored in the layer.
    type Value;

    /// Returns a shared reference to the layer's underlying map.
    fn as_hash_map(&self) -> &HashMap<Self::Key, CheckpointEntry<Self::Value>>;

    /// Returns an exclusive reference to the layer's underlying map.
    fn as_hash_map_mut(
        &mut self,
    ) -> &mut HashMap<Self::Key, CheckpointEntry<Self::Value>>;
}

/// A stack of checkpoint layers in which discarding a checkpoint is lazy:
/// `discard_checkpoint` only removes the checkpoint's start marker, while
/// the layers themselves stay in `inner_vec` until either the last
/// checkpoint is discarded (everything is cleared) or a revert splits them
/// off.
#[derive(Debug, Clone)]
pub struct LazyDiscardedVec<T: CheckpointLayerTrait> {
    /// Every layer currently stored, in push order. May contain layers
    /// whose checkpoint marker has already been discarded.
    inner_vec: Vec<T>,
    /// One entry per live checkpoint: the index in `inner_vec` of the layer
    /// that was pushed together with that checkpoint.
    bundle_start_indices: Vec<usize>,
}

impl<T: CheckpointLayerTrait> Default for LazyDiscardedVec<T> {
    /// Creates a vector holding no layers and no checkpoints.
    fn default() -> Self {
        let inner_vec = Vec::new();
        let bundle_start_indices = Vec::new();
        Self {
            inner_vec,
            bundle_start_indices,
        }
    }
}

impl<T: CheckpointLayerTrait> LazyDiscardedVec<T> {
    /// Number of layers physically stored, including those whose
    /// checkpoint marker has already been discarded.
    fn total_len(&self) -> usize { self.inner_vec.len() }

    /// Returns `true` when no checkpoint is live.
    ///
    /// # Panics
    ///
    /// Panics if there is no live checkpoint but layers are still stored.
    pub fn is_empty(&self) -> bool {
        let no_checkpoints = self.bundle_start_indices.is_empty();
        if no_checkpoints {
            assert!(self.inner_vec.is_empty());
        }
        no_checkpoints
    }

    /// Opens a new checkpoint whose first layer is `new_element`.
    ///
    /// Returns the position of the new checkpoint among the live
    /// checkpoints (zero-based).
    pub fn push_checkpoint(&mut self, new_element: T) -> usize {
        let checkpoint_index = self.bundle_start_indices.len();
        let layer_index = self.total_len();
        self.bundle_start_indices.push(layer_index);
        self.inner_vec.push(new_element);
        checkpoint_index
    }

    /// Discards the most recent checkpoint.
    ///
    /// Returns `None` when there was no checkpoint to discard, or when
    /// other checkpoints remain; in the latter case only the marker is
    /// dropped and the layers are kept. When the last checkpoint is
    /// discarded, all layers are removed and the set of every key they
    /// contained is returned.
    pub fn discard_checkpoint(&mut self) -> Option<HashSet<T::Key>> {
        self.bundle_start_indices.pop()?;

        if !self.bundle_start_indices.is_empty() {
            return None;
        }
        Some(self.discard_all_checkpoints())
    }

    /// Collects the keys of every stored layer, then empties the vector.
    fn discard_all_checkpoints(&mut self) -> HashSet<T::Key> {
        let mut cleared_keys = HashSet::new();
        for layer in &self.inner_vec {
            cleared_keys.extend(layer.as_hash_map().keys().cloned());
        }
        self.clear();
        cleared_keys
    }

    /// Removes the most recent checkpoint together with every layer stored
    /// from its start index onwards.
    ///
    /// Returns `None` when there is no checkpoint. Otherwise yields the
    /// removed layers paired with the index each one had in the layer
    /// vector, from the newest layer to the oldest.
    ///
    /// # Panics
    ///
    /// Panics if the checkpoint's start index is not a valid layer index.
    pub fn revert_to_checkpoint(
        &mut self,
    ) -> Option<impl Iterator<Item = (usize, T)>> {
        let first_layer_id = self.bundle_start_indices.pop()?;
        assert!(first_layer_id < self.total_len());

        let reverted_layers = self.inner_vec.split_off(first_layer_id);

        // Offsets within the split-off tail are shifted back to the
        // indices the layers had before removal.
        let newest_first = reverted_layers
            .into_iter()
            .enumerate()
            .rev()
            .map(move |(offset, layer)| (first_layer_id + offset, layer));
        Some(newest_first)
    }

    /// Removes all checkpoints and all layers.
    pub fn clear(&mut self) {
        self.bundle_start_indices.clear();
        self.inner_vec.clear();
    }

    /// Records an entry for `key` in the newest layer.
    ///
    /// Does nothing when no checkpoint is live. An existing entry for
    /// `key` in that layer is left untouched, and `value` is then never
    /// called. Otherwise `value` receives the index of the newest layer and
    /// its result is stored.
    pub fn insert_element(
        &mut self, key: T::Key,
        value: impl FnOnce(usize) -> CheckpointEntry<T::Value>,
    ) {
        if self.is_empty() {
            return;
        }

        let top_layer_id = self.total_len() - 1;
        let top_layer = self.inner_vec.last_mut().unwrap().as_hash_map_mut();
        top_layer.entry(key).or_insert_with(|| value(top_layer_id));
    }

    /// Number of live (not yet discarded) checkpoints.
    #[cfg(test)]
    fn undiscarded_len(&self) -> usize { self.bundle_start_indices.len() }

    /// Number of live checkpoints.
    #[cfg(test)]
    pub fn len(&self) -> usize { self.undiscarded_len() }

    /// Iterates over the stored layers starting at the first layer of the
    /// live checkpoint `undiscard_element_index`.
    ///
    /// An index at or past the number of live checkpoints yields an empty
    /// iterator.
    #[cfg(test)]
    pub fn elements_from_index(
        &self, undiscard_element_index: usize,
    ) -> impl Iterator<Item = &T> {
        let element_index = self
            .bundle_start_indices
            .get(undiscard_element_index)
            .copied()
            .unwrap_or_else(|| self.total_len());
        self.inner_vec.iter().skip(element_index)
    }
}