cfx_math/nth_root/
compute.rs

1use cfx_types::{U256, U512};
2use std::{
3    convert::TryFrom,
4    fmt::Debug,
5    ops::{Add, Div, Mul, Shl, Shr, Sub},
6};
7use typenum::Unsigned;
8
9use super::{const_generic::SubU1, root_degree::RootDegree};
10use unroll::unroll_for_loops;
11
12pub trait NthRoot:
13    Copy
14    + Mul<Output = Self>
15    + Ord
16    + Shl<usize, Output = Self>
17    + Shr<usize, Output = Self>
18    + Div<Output = Self>
19    + Add<Output = Self>
20    + Sub<Output = Self>
21    + From<u64>
22    + Debug
23{
24    const BITS: usize;
25    const MAX: Self;
26
27    fn checked_mul(self, other: Self) -> Option<Self>;
28    fn mul_usize(self, other: usize) -> Self;
29    fn div_usize(self, other: usize) -> Self;
30    fn bits(self) -> usize;
31    fn init_root<N: RootDegree>(self) -> InitRoot<Self>;
32
33    #[inline]
34    fn nth_root<N: RootDegree>(self) -> Self {
35        match self.init_root::<N>() {
36            InitRoot::Init(init_root) => {
37                newtons_method::<N, _>(self, init_root)
38            }
39            InitRoot::Done(root) => root,
40        }
41    }
42
43    #[inline]
44    fn truncate(self, next_bits: usize, multiply: usize) -> (Self, usize) {
45        let bits = self.bits();
46        let significant_bits = {
47            let n = multiply;
48            let adjust_bits = (n + (next_bits % n) - bits % n) % n;
49            next_bits - adjust_bits
50        };
51
52        // The `rest_bits` must be multiply of N.
53        let rest_bits = bits - significant_bits;
54        let significant_word = self >> rest_bits;
55        (significant_word, rest_bits)
56    }
57}
58
/// Outcome of the cheap first stage of an integer root computation.
///
/// `nth_root` returns the payload of [`InitRoot::Done`] unchanged and passes
/// the payload of [`InitRoot::Init`] to Newton's method as the starting
/// estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitRoot<I> {
    /// An initial estimate that still needs iterative refinement.
    Init(I),
    /// The final root; no further work is required.
    Done(I),
}
63
/// Returns `true` when `output` is the floor of the `N`-th root of `input`,
/// i.e. `output^N <= input < (output + 1)^N`.
///
/// Intended for degrees `N >= 1`. An overflowing `output^N` means `output` is
/// too large (`false`); an overflowing `(output + 1)^N` — including the case
/// where `output + 1` itself does not fit in `u64` — means the upper bound
/// holds trivially.
#[inline]
fn check_answer<const N: u32>(input: u64, output: u64) -> bool {
    output.checked_pow(N).map_or(false, |pow| pow <= input)
        && output
            // `output + 1` would overflow for `output == u64::MAX`.
            .checked_add(1)
            .and_then(|next| next.checked_pow(N))
            .map_or(true, |pow| pow > input)
}
69
70impl NthRoot for u64 {
71    const BITS: usize = 64;
72    const MAX: u64 = u64::MAX;
73
74    #[inline]
75    fn checked_mul(self, other: Self) -> Option<Self> {
76        self.checked_mul(other)
77    }
78
79    #[inline]
80    fn mul_usize(self, other: usize) -> Self { self * (other as u64) }
81
82    #[inline]
83    fn div_usize(self, other: usize) -> Self { self / (other as u64) }
84
85    #[inline]
86    fn bits(self) -> usize { (u64::BITS - self.leading_zeros()) as usize }
87
88    #[inline]
89    fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
90        if self == 0 {
91            return InitRoot::Done(0);
92        }
93        if self < 1 << N::USIZE {
94            return InitRoot::Done(1);
95        }
96
97        if N::USIZE == 2 {
98            let ans = (self as f64).sqrt() as u64;
99            if check_answer::<2>(self, ans) {
100                return InitRoot::Done(ans);
101            }
102            return InitRoot::Init(ans + 1);
103        }
104
105        if N::USIZE == 4 {
106            let ans = (self as f64).sqrt().sqrt() as u64;
107            if check_answer::<4>(self, ans) {
108                return InitRoot::Done(ans);
109            }
110            return InitRoot::Init(ans + 1);
111        }
112
113        if N::LOOKUP_BITS > 0 {
114            if self < (1 << N::LOOKUP_BITS) - 1 {
115                InitRoot::Done(N::nth_root_lookup(self))
116            } else {
117                let (small, rot) =
118                    self.truncate(N::LOOKUP_BITS as usize, N::USIZE);
119                InitRoot::Init(
120                    (N::nth_root_lookup(small) + 1) << (rot / N::USIZE),
121                )
122            }
123        } else {
124            InitRoot::Init(
125                ((self as f64).ln() / f64::from(N::U32)).exp() as u64 + 1,
126            )
127        }
128    }
129}
130
131impl NthRoot for u128 {
132    const BITS: usize = 128;
133    const MAX: u128 = u128::MAX;
134
135    #[inline]
136    fn checked_mul(self, other: Self) -> Option<Self> {
137        self.checked_mul(other)
138    }
139
140    #[inline]
141    fn mul_usize(self, other: usize) -> Self { self * (other as u128) }
142
143    #[inline]
144    fn div_usize(self, other: usize) -> Self { self / (other as u128) }
145
146    #[inline]
147    fn bits(self) -> usize { (u128::BITS - self.leading_zeros()) as usize }
148
149    #[inline]
150    fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
151        let compute_next = |me: u128| (me as u64).nth_root::<N>() as u128;
152        if self < u64::MAX as u128 {
153            InitRoot::Done(compute_next(self))
154        } else {
155            InitRoot::Init({
156                let (next, rot) = self.truncate(64, N::USIZE);
157                (compute_next(next) + 1) << (rot / N::USIZE)
158            })
159        }
160    }
161}
162
163impl NthRoot for U256 {
164    const BITS: usize = 256;
165    const MAX: U256 = U256::MAX;
166
167    #[inline]
168    fn checked_mul(self, other: Self) -> Option<Self> {
169        self.checked_mul(other)
170    }
171
172    #[inline]
173    fn mul_usize(self, other: usize) -> Self { self * other }
174
175    #[inline]
176    fn div_usize(self, other: usize) -> Self { self / other }
177
178    #[inline]
179    fn bits(self) -> usize { U256::bits(&self) }
180
181    #[inline]
182    fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
183        let compute_next = |me: U256| U256::from(me.as_u128().nth_root::<N>());
184        if &self.0[2..4] == &[0, 0] {
185            InitRoot::Done(compute_next(self))
186        } else {
187            InitRoot::Init({
188                let (next, rot) = self.truncate(128, N::USIZE);
189                (compute_next(next) + 1) << (rot / N::USIZE)
190            })
191        }
192    }
193}
194
195impl NthRoot for U512 {
196    const BITS: usize = 512;
197    const MAX: U512 = U512::MAX;
198
199    #[inline]
200    fn checked_mul(self, other: Self) -> Option<Self> {
201        self.checked_mul(other)
202    }
203
204    #[inline]
205    fn mul_usize(self, other: usize) -> Self { self * other }
206
207    #[inline]
208    fn div_usize(self, other: usize) -> Self { self / other }
209
210    #[inline]
211    fn bits(self) -> usize { U512::bits(&self) }
212
213    #[inline]
214    fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
215        let compute_next =
216            |me: U512| U512::from(U256::try_from(me).unwrap().nth_root::<N>());
217        if &self.0[4..8] == &[0, 0, 0, 0] {
218            InitRoot::Done(compute_next(self))
219        } else {
220            InitRoot::Init({
221                let (next, rot) = self.truncate(256, N::USIZE);
222                (compute_next(next) + 1) << (rot / N::USIZE)
223            })
224        }
225    }
226}
227
228#[inline]
229fn newtons_method<N: RootDegree, I: NthRoot>(input: I, init_root: I) -> I {
230    let mut root = init_root;
231    loop {
232        let pow_n_1 = pow::<<N as SubU1>::Output, I>(root);
233        let pow_n = pow_n_1.checked_mul(root);
234
235        if pow_n.map_or(false, |x| x <= input) {
236            return root;
237        }
238
239        let mut fast_compute_root = None;
240        if I::BITS == 256 {
241            if let Some(pow_n) = pow_n {
242                let divisor = pow_n_1.mul_usize(N::USIZE);
243                fast_compute_root = Some(
244                    root - (pow_n - input - 1.into()) / divisor - 1.into(),
245                );
246            }
247        }
248        root = if let Some(root) = fast_compute_root {
249            root
250        } else {
251            (input / pow_n_1 + root.mul_usize(N::USIZE - 1)).div_usize(N::USIZE)
252        };
253    }
254}
255
256#[inline]
257#[allow(unused_assignments)]
258#[unroll_for_loops]
259pub(super) fn pow<N: Unsigned, I: Copy + From<u64> + Mul<Output = I>>(
260    input: I,
261) -> I {
262    let pow = N::USIZE;
263    match pow {
264        0 => {
265            return I::from(1u64);
266        }
267        1 => {
268            return input;
269        }
270        2 => {
271            return input * input;
272        }
273        3 => {
274            return input * input * input;
275        }
276        _ => {}
277    }
278
279    let mut base = input;
280    let mut acc = I::from(1u64);
281
282    for bit in 0u32..32 {
283        if (pow & (1 << bit)) > 0 {
284            acc = acc * base
285        }
286        if bit < 31 && (pow >> (bit + 1)) > 0 {
287            base = base * base
288        }
289    }
290    acc
291}