cfx_math/nth_root/
compute.rs1use cfx_types::{U256, U512};
2use std::{
3 convert::TryFrom,
4 fmt::Debug,
5 ops::{Add, Div, Mul, Shl, Shr, Sub},
6};
7use typenum::Unsigned;
8
9use super::{const_generic::SubU1, root_degree::RootDegree};
10use unroll::unroll_for_loops;
11
12pub trait NthRoot:
13 Copy
14 + Mul<Output = Self>
15 + Ord
16 + Shl<usize, Output = Self>
17 + Shr<usize, Output = Self>
18 + Div<Output = Self>
19 + Add<Output = Self>
20 + Sub<Output = Self>
21 + From<u64>
22 + Debug
23{
24 const BITS: usize;
25 const MAX: Self;
26
27 fn checked_mul(self, other: Self) -> Option<Self>;
28 fn mul_usize(self, other: usize) -> Self;
29 fn div_usize(self, other: usize) -> Self;
30 fn bits(self) -> usize;
31 fn init_root<N: RootDegree>(self) -> InitRoot<Self>;
32
33 #[inline]
34 fn nth_root<N: RootDegree>(self) -> Self {
35 match self.init_root::<N>() {
36 InitRoot::Init(init_root) => {
37 newtons_method::<N, _>(self, init_root)
38 }
39 InitRoot::Done(root) => root,
40 }
41 }
42
43 #[inline]
44 fn truncate(self, next_bits: usize, multiply: usize) -> (Self, usize) {
45 let bits = self.bits();
46 let significant_bits = {
47 let n = multiply;
48 let adjust_bits = (n + (next_bits % n) - bits % n) % n;
49 next_bits - adjust_bits
50 };
51
52 let rest_bits = bits - significant_bits;
54 let significant_word = self >> rest_bits;
55 (significant_word, rest_bits)
56 }
57}
58
/// Outcome of the root-initialisation step.
///
/// The common traits are derived so values can be logged, compared and
/// copied whenever the payload type allows it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitRoot<I> {
    /// A starting estimate that still has to be refined by Newton's method.
    Init(I),
    /// The final root; no further iteration is needed.
    Done(I),
}
63
/// Returns `true` when `output` is the integer `N`-th root of `input`, i.e.
/// `output^N <= input < (output + 1)^N`.
///
/// Arithmetic that does not fit in a `u64` is treated as "larger than any
/// `u64`": an overflowing `output^N` fails the lower bound, while an
/// overflowing `output + 1` or `(output + 1)^N` satisfies the upper bound.
#[inline]
fn check_answer<const N: u32>(input: u64, output: u64) -> bool {
    // Lower bound: `output^N` must be representable and not exceed `input`.
    let lower_ok = output.checked_pow(N).map_or(false, |power| power <= input);

    // Upper bound: the next candidate must overshoot `input`. `checked_add`
    // avoids the overflow of `output + 1` when `output == u64::MAX`.
    lower_ok
        && output
            .checked_add(1)
            .and_then(|next| next.checked_pow(N))
            .map_or(true, |power| power > input)
}
69
70impl NthRoot for u64 {
71 const BITS: usize = 64;
72 const MAX: u64 = u64::MAX;
73
74 #[inline]
75 fn checked_mul(self, other: Self) -> Option<Self> {
76 self.checked_mul(other)
77 }
78
79 #[inline]
80 fn mul_usize(self, other: usize) -> Self { self * (other as u64) }
81
82 #[inline]
83 fn div_usize(self, other: usize) -> Self { self / (other as u64) }
84
85 #[inline]
86 fn bits(self) -> usize { (u64::BITS - self.leading_zeros()) as usize }
87
88 #[inline]
89 fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
90 if self == 0 {
91 return InitRoot::Done(0);
92 }
93 if self < 1 << N::USIZE {
94 return InitRoot::Done(1);
95 }
96
97 if N::USIZE == 2 {
98 let ans = (self as f64).sqrt() as u64;
99 if check_answer::<2>(self, ans) {
100 return InitRoot::Done(ans);
101 }
102 return InitRoot::Init(ans + 1);
103 }
104
105 if N::USIZE == 4 {
106 let ans = (self as f64).sqrt().sqrt() as u64;
107 if check_answer::<4>(self, ans) {
108 return InitRoot::Done(ans);
109 }
110 return InitRoot::Init(ans + 1);
111 }
112
113 if N::LOOKUP_BITS > 0 {
114 if self < (1 << N::LOOKUP_BITS) - 1 {
115 InitRoot::Done(N::nth_root_lookup(self))
116 } else {
117 let (small, rot) =
118 self.truncate(N::LOOKUP_BITS as usize, N::USIZE);
119 InitRoot::Init(
120 (N::nth_root_lookup(small) + 1) << (rot / N::USIZE),
121 )
122 }
123 } else {
124 InitRoot::Init(
125 ((self as f64).ln() / f64::from(N::U32)).exp() as u64 + 1,
126 )
127 }
128 }
129}
130
131impl NthRoot for u128 {
132 const BITS: usize = 128;
133 const MAX: u128 = u128::MAX;
134
135 #[inline]
136 fn checked_mul(self, other: Self) -> Option<Self> {
137 self.checked_mul(other)
138 }
139
140 #[inline]
141 fn mul_usize(self, other: usize) -> Self { self * (other as u128) }
142
143 #[inline]
144 fn div_usize(self, other: usize) -> Self { self / (other as u128) }
145
146 #[inline]
147 fn bits(self) -> usize { (u128::BITS - self.leading_zeros()) as usize }
148
149 #[inline]
150 fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
151 let compute_next = |me: u128| (me as u64).nth_root::<N>() as u128;
152 if self < u64::MAX as u128 {
153 InitRoot::Done(compute_next(self))
154 } else {
155 InitRoot::Init({
156 let (next, rot) = self.truncate(64, N::USIZE);
157 (compute_next(next) + 1) << (rot / N::USIZE)
158 })
159 }
160 }
161}
162
163impl NthRoot for U256 {
164 const BITS: usize = 256;
165 const MAX: U256 = U256::MAX;
166
167 #[inline]
168 fn checked_mul(self, other: Self) -> Option<Self> {
169 self.checked_mul(other)
170 }
171
172 #[inline]
173 fn mul_usize(self, other: usize) -> Self { self * other }
174
175 #[inline]
176 fn div_usize(self, other: usize) -> Self { self / other }
177
178 #[inline]
179 fn bits(self) -> usize { U256::bits(&self) }
180
181 #[inline]
182 fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
183 let compute_next = |me: U256| U256::from(me.as_u128().nth_root::<N>());
184 if &self.0[2..4] == &[0, 0] {
185 InitRoot::Done(compute_next(self))
186 } else {
187 InitRoot::Init({
188 let (next, rot) = self.truncate(128, N::USIZE);
189 (compute_next(next) + 1) << (rot / N::USIZE)
190 })
191 }
192 }
193}
194
195impl NthRoot for U512 {
196 const BITS: usize = 512;
197 const MAX: U512 = U512::MAX;
198
199 #[inline]
200 fn checked_mul(self, other: Self) -> Option<Self> {
201 self.checked_mul(other)
202 }
203
204 #[inline]
205 fn mul_usize(self, other: usize) -> Self { self * other }
206
207 #[inline]
208 fn div_usize(self, other: usize) -> Self { self / other }
209
210 #[inline]
211 fn bits(self) -> usize { U512::bits(&self) }
212
213 #[inline]
214 fn init_root<N: RootDegree>(self) -> InitRoot<Self> {
215 let compute_next =
216 |me: U512| U512::from(U256::try_from(me).unwrap().nth_root::<N>());
217 if &self.0[4..8] == &[0, 0, 0, 0] {
218 InitRoot::Done(compute_next(self))
219 } else {
220 InitRoot::Init({
221 let (next, rot) = self.truncate(256, N::USIZE);
222 (compute_next(next) + 1) << (rot / N::USIZE)
223 })
224 }
225 }
226}
227
228#[inline]
229fn newtons_method<N: RootDegree, I: NthRoot>(input: I, init_root: I) -> I {
230 let mut root = init_root;
231 loop {
232 let pow_n_1 = pow::<<N as SubU1>::Output, I>(root);
233 let pow_n = pow_n_1.checked_mul(root);
234
235 if pow_n.map_or(false, |x| x <= input) {
236 return root;
237 }
238
239 let mut fast_compute_root = None;
240 if I::BITS == 256 {
241 if let Some(pow_n) = pow_n {
242 let divisor = pow_n_1.mul_usize(N::USIZE);
243 fast_compute_root = Some(
244 root - (pow_n - input - 1.into()) / divisor - 1.into(),
245 );
246 }
247 }
248 root = if let Some(root) = fast_compute_root {
249 root
250 } else {
251 (input / pow_n_1 + root.mul_usize(N::USIZE - 1)).div_usize(N::USIZE)
252 };
253 }
254}
255
256#[inline]
257#[allow(unused_assignments)]
258#[unroll_for_loops]
259pub(super) fn pow<N: Unsigned, I: Copy + From<u64> + Mul<Output = I>>(
260 input: I,
261) -> I {
262 let pow = N::USIZE;
263 match pow {
264 0 => {
265 return I::from(1u64);
266 }
267 1 => {
268 return input;
269 }
270 2 => {
271 return input * input;
272 }
273 3 => {
274 return input * input * input;
275 }
276 _ => {}
277 }
278
279 let mut base = input;
280 let mut acc = I::from(1u64);
281
282 for bit in 0u32..32 {
283 if (pow & (1 << bit)) > 0 {
284 acc = acc * base
285 }
286 if bit < 31 && (pow >> (bit + 1)) > 0 {
287 base = base * base
288 }
289 }
290 acc
291}