diem_jellyfish_merkle/nibble_path/
mod.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4// Copyright 2021 Conflux Foundation. All rights reserved.
5// Conflux is free software and distributed under GNU General Public License.
6// See http://www.gnu.org/licenses/
7
//! The NibblePath library simplifies operations on nibbles stored in a compact
//! format for the modified sparse Merkle tree by providing powerful iterators
//! that advance by either bit or nibble.
11
12#[cfg(test)]
13mod nibble_path_test;
14
15use crate::ROOT_NIBBLE_HEIGHT;
16use diem_nibble::Nibble;
17use mirai_annotations::*;
18#[cfg(any(test, feature = "fuzzing"))]
19use proptest::{collection::vec, prelude::*};
20use serde::{Deserialize, Serialize};
21use std::{fmt, iter::FromIterator};
22
23/// NibblePath defines a path in Merkle tree in the unit of nibble (4 bits).
24#[derive(
25    Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
26)]
27pub struct NibblePath {
28    /// Indicates the total number of nibbles in bytes. Either `bytes.len() * 2
29    /// - 1` or `bytes.len() * 2`.
30    // Guarantees intended ordering based on the top-to-bottom declaration
31    // order of the struct's members.
32    num_nibbles: usize,
33    /// The underlying bytes that stores the path, 2 nibbles per byte. If the
34    /// number of nibbles is odd, the second half of the last byte must be
35    /// 0.
36    bytes: Vec<u8>,
37    // invariant num_nibbles <= ROOT_NIBBLE_HEIGHT
38}
39
40/// Supports debug format by concatenating nibbles literally. For example,
41/// [0x12, 0xa0] with 3 nibbles will be printed as "12a".
42impl fmt::Debug for NibblePath {
43    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44        self.nibbles().try_for_each(|x| write!(f, "{:x}", x))
45    }
46}
47
48/// Convert a vector of bytes into `NibblePath` using the lower 4 bits of each
49/// byte as nibble.
50impl FromIterator<Nibble> for NibblePath {
51    fn from_iter<I: IntoIterator<Item = Nibble>>(iter: I) -> Self {
52        let mut nibble_path = NibblePath::new(vec![]);
53        for nibble in iter {
54            nibble_path.push(nibble);
55        }
56        nibble_path
57    }
58}
59
// Lets proptest produce `NibblePath` values directly; generation is delegated
// to `arb_nibble_path` and takes no parameters.
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for NibblePath {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
        let strategy = arb_nibble_path();
        strategy.boxed()
    }
}
69
#[cfg(any(test, feature = "fuzzing"))]
prop_compose! {
    // Strategy producing nibble paths of even or odd length from up to
    // `ROOT_NIBBLE_HEIGHT / 2` random bytes.
    fn arb_nibble_path()(
        mut bytes in vec(any::<u8>(), 0..=ROOT_NIBBLE_HEIGHT/2),
        is_odd in any::<bool>()
    ) -> NibblePath {
        // An odd-length path needs at least one byte to drop a nibble from;
        // with no bytes the result is always the empty (even) path.
        let make_odd = is_odd && !bytes.is_empty();
        if !make_odd {
            return NibblePath::new(bytes);
        }
        // Zero the low half of the final byte, as `new_odd` requires.
        if let Some(tail) = bytes.last_mut() {
            *tail &= 0xf0;
        }
        NibblePath::new_odd(bytes)
    }
}
85
#[cfg(any(test, feature = "fuzzing"))]
prop_compose! {
    // Same as `arb_nibble_path`, but rejects paths that already hold
    // `ROOT_NIBBLE_HEIGHT` nibbles.
    fn arb_internal_nibble_path()(
        candidate in arb_nibble_path().prop_filter(
            "Filter out leaf paths.",
            |path| ROOT_NIBBLE_HEIGHT > path.num_nibbles(),
        )
    ) -> NibblePath {
        candidate
    }
}
97
98impl NibblePath {
99    /// Creates a new `NibblePath` from a vector of bytes assuming each byte has
100    /// 2 nibbles.
101    pub fn new(bytes: Vec<u8>) -> Self {
102        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
103        let num_nibbles = bytes.len() * 2;
104        NibblePath { bytes, num_nibbles }
105    }
106
107    /// Similar to `new()` but assumes that the bytes have one less nibble.
108    pub fn new_odd(bytes: Vec<u8>) -> Self {
109        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
110        assert_eq!(
111            bytes.last().expect("Should have odd number of nibbles.") & 0x0f,
112            0,
113            "Last nibble must be 0."
114        );
115        let num_nibbles = bytes.len() * 2 - 1;
116        NibblePath { bytes, num_nibbles }
117    }
118
119    /// Adds a nibble to the end of the nibble path.
120    pub fn push(&mut self, nibble: Nibble) {
121        assert!(ROOT_NIBBLE_HEIGHT > self.num_nibbles);
122        if self.num_nibbles % 2 == 0 {
123            self.bytes.push(u8::from(nibble) << 4);
124        } else {
125            self.bytes[self.num_nibbles / 2] |= u8::from(nibble);
126        }
127        self.num_nibbles += 1;
128    }
129
130    /// Pops a nibble from the end of the nibble path.
131    pub fn pop(&mut self) -> Option<Nibble> {
132        let poped_nibble = if self.num_nibbles % 2 == 0 {
133            self.bytes.last_mut().map(|last_byte| {
134                let nibble = *last_byte & 0x0f;
135                *last_byte &= 0xf0;
136                Nibble::from(nibble)
137            })
138        } else {
139            self.bytes.pop().map(|byte| Nibble::from(byte >> 4))
140        };
141        if poped_nibble.is_some() {
142            self.num_nibbles -= 1;
143        }
144        poped_nibble
145    }
146
147    /// Returns the last nibble.
148    pub fn last(&self) -> Option<Nibble> {
149        let last_byte_option = self.bytes.last();
150        if self.num_nibbles % 2 == 0 {
151            last_byte_option.map(|last_byte| Nibble::from(*last_byte & 0x0f))
152        } else {
153            let last_byte = last_byte_option
154                .expect("Last byte must exist if num_nibbles is odd.");
155            Some(Nibble::from(*last_byte >> 4))
156        }
157    }
158
159    /// Get the i-th bit.
160    fn get_bit(&self, i: usize) -> bool {
161        assert!(i / 4 < self.num_nibbles);
162        let pos = i / 8;
163        let bit = 7 - i % 8;
164        ((self.bytes[pos] >> bit) & 1) != 0
165    }
166
167    /// Get the i-th nibble.
168    fn get_nibble(&self, i: usize) -> Nibble {
169        assert!(i < self.num_nibbles);
170        Nibble::from(
171            (self.bytes[i / 2] >> (if i % 2 == 1 { 0 } else { 4 })) & 0xf,
172        )
173    }
174
175    /// Get a bit iterator iterates over the whole nibble path.
176    pub fn bits(&self) -> BitIterator<'_> {
177        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
178        BitIterator {
179            nibble_path: self,
180            pos: (0..self.num_nibbles * 4),
181        }
182    }
183
184    /// Get a nibble iterator iterates over the whole nibble path.
185    pub fn nibbles(&self) -> NibbleIterator<'_> {
186        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
187        NibbleIterator::new(self, 0, self.num_nibbles)
188    }
189
190    /// Get the total number of nibbles stored.
191    pub fn num_nibbles(&self) -> usize { self.num_nibbles }
192
193    /// Get the underlying bytes storing nibbles.
194    pub fn bytes(&self) -> &[u8] { &self.bytes }
195}
196
/// An iterator that can expose its upcoming item without consuming it.
pub trait Peekable: Iterator {
    /// Yields what `next()` would return, leaving the iterator where it is.
    fn peek(&self) -> Option<Self::Item>;
}
201
202/// BitIterator iterates a nibble path by bit.
203pub struct BitIterator<'a> {
204    nibble_path: &'a NibblePath,
205    pos: std::ops::Range<usize>,
206}
207
208impl<'a> Peekable for BitIterator<'a> {
209    /// Returns the `next()` value without advancing the iterator.
210    fn peek(&self) -> Option<Self::Item> {
211        if self.pos.start < self.pos.end {
212            Some(self.nibble_path.get_bit(self.pos.start))
213        } else {
214            None
215        }
216    }
217}
218
219/// BitIterator spits out a boolean each time. True/false denotes 1/0.
220impl<'a> Iterator for BitIterator<'a> {
221    type Item = bool;
222
223    fn next(&mut self) -> Option<Self::Item> {
224        self.pos.next().map(|i| self.nibble_path.get_bit(i))
225    }
226}
227
228/// Support iterating bits in reversed order.
229impl<'a> DoubleEndedIterator for BitIterator<'a> {
230    fn next_back(&mut self) -> Option<Self::Item> {
231        self.pos.next_back().map(|i| self.nibble_path.get_bit(i))
232    }
233}
234
235/// NibbleIterator iterates a nibble path by nibble.
236#[derive(Debug)]
237pub struct NibbleIterator<'a> {
238    /// The underlying nibble path that stores the nibbles
239    nibble_path: &'a NibblePath,
240
241    /// The current index, `pos.start`, will bump by 1 after calling `next()`
242    /// until `pos.start == pos.end`.
243    pos: std::ops::Range<usize>,
244
245    /// The start index of the iterator. At the beginning, `pos.start ==
246    /// start`. [start, pos.end) defines the range of `nibble_path` this
247    /// iterator iterates over. `nibble_path` refers to the entire
248    /// underlying buffer but the range may only be partial.
249    start: usize,
250    /* invariant self.start <= self.pos.start;
251     * invariant self.pos.start <= self.pos.end;
252     * invariant self.pos.end <= ROOT_NIBBLE_HEIGHT; */
253}
254
255/// NibbleIterator spits out a byte each time. Each byte must be in range [0,
256/// 16).
257impl<'a> Iterator for NibbleIterator<'a> {
258    type Item = Nibble;
259
260    fn next(&mut self) -> Option<Self::Item> {
261        self.pos.next().map(|i| self.nibble_path.get_nibble(i))
262    }
263}
264
265impl<'a> Peekable for NibbleIterator<'a> {
266    /// Returns the `next()` value without advancing the iterator.
267    fn peek(&self) -> Option<Self::Item> {
268        if self.pos.start < self.pos.end {
269            Some(self.nibble_path.get_nibble(self.pos.start))
270        } else {
271            None
272        }
273    }
274}
275
276impl<'a> NibbleIterator<'a> {
277    fn new(nibble_path: &'a NibblePath, start: usize, end: usize) -> Self {
278        precondition!(start <= end);
279        precondition!(start <= ROOT_NIBBLE_HEIGHT);
280        precondition!(end <= ROOT_NIBBLE_HEIGHT);
281        Self {
282            nibble_path,
283            pos: (start..end),
284            start,
285        }
286    }
287
288    /// Returns a nibble iterator that iterates all visited nibbles.
289    pub fn visited_nibbles(&self) -> NibbleIterator<'a> {
290        assume!(self.start <= self.pos.start); // invariant
291        assume!(self.pos.start <= ROOT_NIBBLE_HEIGHT); // invariant
292        Self::new(self.nibble_path, self.start, self.pos.start)
293    }
294
295    /// Returns a nibble iterator that iterates all remaining nibbles.
296    pub fn remaining_nibbles(&self) -> NibbleIterator<'a> {
297        assume!(self.pos.start <= self.pos.end); // invariant
298        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
299        Self::new(self.nibble_path, self.pos.start, self.pos.end)
300    }
301
302    /// Turn it into a `BitIterator`.
303    pub fn bits(&self) -> BitIterator<'a> {
304        assume!(self.pos.start <= self.pos.end); // invariant
305        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
306        BitIterator {
307            nibble_path: self.nibble_path,
308            pos: (self.pos.start * 4..self.pos.end * 4),
309        }
310    }
311
312    /// Cut and return the range of the underlying `nibble_path` that this
313    /// iterator is iterating over as a new `NibblePath`
314    pub fn get_nibble_path(&self) -> NibblePath {
315        self.visited_nibbles()
316            .chain(self.remaining_nibbles())
317            .collect()
318    }
319
320    /// Get the number of nibbles that this iterator covers.
321    pub fn num_nibbles(&self) -> usize {
322        assume!(self.start <= self.pos.end); // invariant
323        self.pos.end - self.start
324    }
325
326    /// Return `true` if the iteration is over.
327    pub fn is_finished(&self) -> bool { self.peek().is_none() }
328}
329
330/// Advance both iterators if their next nibbles are the same until either
331/// reaches the end or the find a mismatch. Return the number of matched
332/// nibbles.
333pub fn skip_common_prefix<'a, 'b, I1: 'a, I2: 'b>(
334    x: &'a mut I1, y: &mut I2,
335) -> usize
336where
337    I1: Iterator + Peekable,
338    I2: Iterator + Peekable,
339    <I1 as Iterator>::Item: std::cmp::PartialEq<<I2 as Iterator>::Item>,
340{
341    let mut count = 0;
342    loop {
343        let x_peek = x.peek();
344        let y_peek = y.peek();
345        if x_peek.is_none()
346            || y_peek.is_none()
347            || x_peek.expect("cannot be none")
348                != y_peek.expect("cannot be none")
349        {
350            break;
351        }
352        count += 1;
353        x.next();
354        y.next();
355    }
356    count
357}