1use crate::WireLabel;
6use itertools::Itertools;
7use std::collections::HashMap;
8use vectoreyes::U8x16;
9
/// Builds the tweak for a single index by zero-extending `i` to 128 bits.
pub fn tweak(i: usize) -> u128 {
    i as u128
}
17
/// Builds a two-index tweak: `i` fills the low 64 bits and `j` the high 64 bits.
pub fn tweak2(i: u64, j: u64) -> u128 {
    let low = u128::from(i);
    let high = u128::from(j) << 64;
    high | low
}
22
/// Builds an output tweak: index `i` shifted into the upper 64 bits, plus `k`.
///
/// The low 64 bits of the shifted index are zero, so adding the 16-bit `k`
/// places it in the low bits without carrying.
pub fn output_tweak(i: usize, k: u16) -> u128 {
    let high = (i as u128) << 64;
    high + u128::from(k)
}
28
/// Adds the base-`q` number `ys` into `xs` in place.
///
/// Both numbers are little-endian digit vectors (least-significant digit
/// first). The carry from the low digits ripples into the remaining digits of
/// `xs`. A carry out of the most significant digit of `xs` is discarded.
///
/// Sums are computed in `u32`. A digit sum `x + y + carry` can be as large as
/// `2q - 1`, which does not fit in a `u16` once `q > 32768`.
///
/// # Panics
///
/// Panics if `ys` has more digits than `xs`. The length mismatch is also
/// reported by a `debug_assert!` with full context in debug builds.
pub fn base_q_add_eq(xs: &mut [u16], ys: &[u16], q: u16) {
    debug_assert!(
        xs.len() >= ys.len(),
        "q={} xs.len()={} ys.len()={} xs={:?} ys={:?}",
        q,
        xs.len(),
        ys.len(),
        xs,
        ys
    );

    let modulus = u32::from(q);
    // `split_at_mut` panics when `ys` is longer than `xs`, matching the
    // out-of-bounds panic of a plain indexed loop in release builds.
    let (low, high) = xs.split_at_mut(ys.len());

    // Digit-wise addition where both operands have digits.
    let mut carry: u32 = 0;
    for (x, &y) in low.iter_mut().zip(ys) {
        let sum = u32::from(*x) + u32::from(y) + carry;
        carry = u32::from(sum >= modulus);
        *x = (sum - carry * modulus) as u16;
    }

    // Ripple the final carry through the remaining digits of `xs`, stopping at
    // the first digit that does not wrap around.
    for x in high.iter_mut() {
        let sum = u32::from(*x) + carry;
        if sum >= modulus {
            *x = (sum - modulus) as u16;
        } else {
            *x = sum as u16;
            break;
        }
    }
}
67
68fn as_base_q(x: u128, q: u16, n: usize) -> Vec<u16> {
70 let ms = std::iter::repeat_n(q, n).collect_vec();
71 as_mixed_radix(x, &ms)
72}
73
/// Number of base-`modulus` digits used to represent a value packed in a `u128`.
///
/// Moduli up to 512 use a fixed table. Larger moduli use
/// `floor(128 / ceil(log2(modulus)))`.
///
/// A modulus of 0 or 1 trips a `debug_assert!` in debug builds.
pub fn digits_per_u128(modulus: u16) -> usize {
    debug_assert_ne!(modulus, 0);
    debug_assert_ne!(modulus, 1);
    match modulus {
        2 => 128,
        // 0 and 1 also land here in release builds (2 was matched above).
        0..=4 => 64,
        5..=8 => 42,
        9..=16 => 32,
        17..=32 => 25,
        33..=64 => 21,
        65..=128 => 18,
        129..=256 => 16,
        257..=512 => 14,
        _ => {
            let bits_per_digit = f64::from(modulus).log2().ceil();
            (128.0 / bits_per_digit).floor() as usize
        }
    }
}
101
102pub fn as_base_q_u128(x: u128, q: u16) -> Vec<u16> {
104 as_base_q(x, q, digits_per_u128(q))
105}
106
/// Decomposes `x` into mixed-radix digits, least-significant first.
///
/// Digit `i` lies in `0..radii[i]`. The output has exactly `radii.len()`
/// digits, and any part of `x` too large for them is dropped.
///
/// # Panics
///
/// Panics (remainder by zero) if any radix is 0.
pub fn as_mixed_radix(x: u128, radii: &[u16]) -> Vec<u16> {
    let mut remaining = x;
    let mut digits = Vec::with_capacity(radii.len());
    for &radix in radii {
        let m = u128::from(radix);
        // When `remaining < m` this yields `remaining` itself and leaves 0 behind.
        digits.push((remaining % m) as u16);
        remaining /= m;
    }
    digits
}
125
/// Reassembles a value from base-`q` digits given least-significant first.
///
/// Inverse of `as_base_q`. A multiplication overflow trips a `debug_assert!`
/// in debug builds.
pub fn from_base_q(ds: &[u16], q: u16) -> u128 {
    let base = u128::from(q);
    ds.iter().rev().fold(0u128, |acc, &digit| {
        let (shifted, overflow) = acc.overflowing_mul(base);
        debug_assert!(!overflow, "overflow!!!! x={}", acc);
        shifted + u128::from(digit)
    })
}
136
/// Reassembles a value from mixed-radix digits given least-significant first.
///
/// Inverse of `as_mixed_radix`. Only the first `min(digits.len(), radii.len())`
/// digit/radix pairs are used. A multiplication overflow trips a
/// `debug_assert!` in debug builds.
pub fn from_mixed_radix(digits: &[u16], radii: &[u16]) -> u128 {
    let pairs = digits.len().min(radii.len());
    let mut acc: u128 = 0;
    // Horner evaluation from the most significant used digit down.
    for idx in (0..pairs).rev() {
        let (scaled, overflow) = acc.overflowing_mul(u128::from(radii[idx]));
        debug_assert!(!overflow, "overflow!!!! x={}", acc);
        acc = scaled + u128::from(digits[idx]);
    }
    acc
}
147
/// Returns the lowest `n` bits of `x` as `0`/`1` values, least-significant first.
///
/// When `n > 128`, the extra positions are filled with zeros.
pub fn u128_to_bits(x: u128, n: usize) -> Vec<u16> {
    let mut rest = x;
    (0..n)
        .map(|_| {
            let bit = (rest & 1) as u16;
            rest >>= 1;
            bit
        })
        .collect()
}
164
/// Packs bits given least-significant first into a `u128`.
///
/// Inverse of `u128_to_bits`. Each entry is treated as a base-2 digit and
/// accumulated as `acc * 2 + b`. Entries are expected to be 0 or 1; larger
/// values are added in as-is. An empty slice yields 0.
///
/// In debug builds this panics on arithmetic overflow if `bs` encodes a value
/// that does not fit in 128 bits.
pub fn u128_from_bits(bs: &[u16]) -> u128 {
    // Horner's rule from the most significant bit down. This performs the same
    // operations as "add, double" over bs[1..] followed by adding bs[0], but it
    // no longer indexes bs[0], which panicked on empty input.
    bs.iter()
        .rev()
        .fold(0u128, |acc, &b| acc * 2 + u128::from(b))
}
175
176pub fn factor(inp: u128) -> Vec<u16> {
185 let mut x = inp;
186 let mut fs = Vec::new();
187 for &p in PRIMES.iter() {
188 let q = p as u128;
189 if x.is_multiple_of(q) {
190 fs.push(p);
191 x /= q;
192 }
193 }
194 if x != 1 {
195 panic!("can only factor numbers with unique prime factors");
196 }
197 fs
198}
199
/// Residue representation of `x`: `x mod p` for each modulus `p` in `ps`.
///
/// # Panics
///
/// Panics (remainder by zero) if any modulus is 0.
pub fn crt(x: u128, ps: &[u16]) -> Vec<u16> {
    let mut residues = Vec::with_capacity(ps.len());
    for &p in ps {
        residues.push((x % u128::from(p)) as u16);
    }
    residues
}
204
205pub fn crt_factor(x: u128, q: u128) -> Vec<u16> {
208 crt(x, &factor(q))
209}
210
211pub fn crt_inv(xs: &[u16], ps: &[u16]) -> u128 {
213 let mut ret = 0;
214 let M = ps.iter().fold(1, |acc, &x| x as i128 * acc);
215 for (&p, &a) in ps.iter().zip(xs.iter()) {
216 let p = p as i128;
217 let q = M / p;
218 ret += a as i128 * inv(q, p) * q;
219 ret %= &M;
220 }
221 ret as u128
222}
223
224pub fn crt_inv_factor(xs: &[u16], q: u128) -> u128 {
226 crt_inv(xs, &factor(q))
227}
228
/// Modular inverse of `inp_a` modulo `inp_b`, via the extended Euclidean algorithm.
///
/// Returns `1` when `inp_b == 1`. The loop tracks only the Bézout coefficient
/// of `inp_a`. A negative result is shifted up by `inp_b` once before it is
/// returned.
///
/// If the inputs are not coprime the result is meaningless. For `inp_a > 0`
/// the loop then panics with a division by zero, while `inp_a == 0` returns 1.
pub fn inv(inp_a: i128, inp_b: i128) -> i128 {
    if inp_b == 1 {
        return 1;
    }

    // (r_prev, r): consecutive remainders; (s_prev, s): matching coefficients of inp_a.
    let (mut r_prev, mut r) = (inp_a, inp_b);
    let (mut s_prev, mut s) = (1i128, 0i128);

    while r_prev > 1 {
        let quotient = r_prev / r;
        (r_prev, r) = (r, r_prev % r);
        (s_prev, s) = (s, s_prev - quotient * s);
    }

    if s_prev < 0 {
        s_prev + inp_b
    } else {
        s_prev
    }
}
261
/// Number of entries in `PRIMES`.
pub const NPRIMES: usize = 29;

/// The first `NPRIMES` primes (2 through 109), in increasing order.
///
/// The array length is written as `NPRIMES` so the two constants cannot drift
/// apart: adding or removing a prime without updating `NPRIMES` is a compile
/// error.
pub const PRIMES: [u16; NPRIMES] = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
    101, 103, 107, 109,
];
270
271pub fn modulus_with_nprimes(n: usize) -> u128 {
279 product(&PRIMES[0..n])
280}
281
282pub fn modulus_with_width(n: u32) -> u128 {
285 base_modulus_with_width(n, &PRIMES)
286}
287
288pub fn primes_with_width(n: u32) -> Vec<u16> {
291 base_primes_with_width(n, &PRIMES)
292}
293
294pub fn base_modulus_with_width(nbits: u32, primes: &[u16]) -> u128 {
296 product(&base_primes_with_width(nbits, primes))
297}
298
/// Takes primes from the front of `primes` until their product is at least `2^nbits`.
///
/// The test `res >> nbits > 0` is equivalent to `res >= 2^nbits`.
///
/// # Panics
///
/// * if `nbits >= 128`: a `u128` product can never reach `2^128`. Without this
///   check, `res >> nbits` would be an over-wide shift. That is a debug panic,
///   and in release builds the shift amount is masked and too few primes are
///   silently returned.
/// * if the running product overflows `u128` before reaching `2^nbits`. The
///   resulting modulus would not be representable.
/// * if all of `primes` together still fall short ("not enough primes!").
pub fn base_primes_with_width(nbits: u32, primes: &[u16]) -> Vec<u16> {
    assert!(
        nbits < 128,
        "base_primes_with_width: nbits must be < 128, got {}",
        nbits
    );

    let mut res: u128 = 1;
    let mut ps = Vec::new();
    for &p in primes {
        res = res
            .checked_mul(u128::from(p))
            .expect("base_primes_with_width: product of primes overflows u128");
        ps.push(p);
        if (res >> nbits) > 0 {
            break;
        }
    }
    assert!((res >> nbits) > 0, "not enough primes!");
    ps
}
313
/// Product of `xs` as a `u128`; the empty product is 1.
///
/// Panics on overflow in debug builds.
pub fn product(xs: &[u16]) -> u128 {
    xs.iter().map(|&x| u128::from(x)).product()
}
324
/// Returns `true` if `x` is a power of two (1, 2, 4, 8, ...).
///
/// Zero is not a power of two. This delegates to `u16::is_power_of_two`
/// rather than using the `x & (x - 1)` trick. That trick underflows at
/// `x == 0`, which panics in debug builds and reports `true` in release builds.
pub fn is_power_of_2(x: u16) -> bool {
    x.is_power_of_two()
}
346
347pub fn generate_deltas<Wire: WireLabel>(primes: &[u16]) -> HashMap<u16, Wire> {
349 let mut deltas = HashMap::new();
350 let mut rng = rand::thread_rng();
351 for q in primes {
352 deltas.insert(*q, Wire::rand_delta(&mut rng, *q));
353 }
354 deltas
355}
356
357pub trait RngExt: rand::Rng + Sized {
359 fn gen_bool(&mut self) -> bool {
361 self.r#gen()
362 }
363 fn gen_u16(&mut self) -> u16 {
365 self.r#gen()
366 }
367 fn gen_u32(&mut self) -> u32 {
369 self.r#gen()
370 }
371 fn gen_u64(&mut self) -> u64 {
373 self.r#gen()
374 }
375 fn gen_usize(&mut self) -> usize {
377 self.r#gen()
378 }
379 fn gen_u128(&mut self) -> u128 {
381 self.r#gen()
382 }
383 fn gen_block(&mut self) -> U8x16 {
385 self.r#gen()
386 }
387 fn gen_usable_block(&mut self, modulus: u16) -> U8x16 {
389 if is_power_of_2(modulus) {
390 let nbits = (modulus - 1).count_ones();
391 if 128 % nbits == 0 {
392 return U8x16::from(self.gen_u128());
393 }
394 }
395 let n = digits_per_u128(modulus);
396 let max = (modulus as u128).pow(n as u32);
397 U8x16::from(self.gen_u128() % max)
398 }
399 fn gen_prime(&mut self) -> u16 {
401 PRIMES[self.r#gen::<usize>() % NPRIMES]
402 }
403 fn gen_modulus(&mut self) -> u16 {
405 2 + (self.r#gen::<u16>() % 111)
406 }
407 fn gen_usable_composite_modulus(&mut self) -> u128 {
409 product(&self.gen_usable_factors())
410 }
411 fn gen_usable_factors(&mut self) -> Vec<u16> {
413 let mut x: u128 = 1;
414 PRIMES[..25]
415 .iter()
416 .cloned()
417 .filter(|_| self.r#gen()) .take_while(|&q| {
419 match x.checked_mul(q as u128) {
421 None => false,
422 Some(y) => {
423 x = y;
424 true
425 }
426 }
427 })
428 .collect()
429 }
430}
431
432impl<R: rand::Rng + Sized> RngExt for R {}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use rand::thread_rng;

    /// Reducing a value to residues modulo the first 25 primes and reconstructing
    /// it with `crt_inv` must round-trip.
    #[test]
    fn crt_conversion() {
        let mut rng = thread_rng();
        let moduli = &PRIMES[..25];
        let total = product(moduli);

        for _ in 0..128 {
            let value = rng.gen_u128() % total;
            let residues = crt(value, moduli);
            assert_eq!(crt_inv(&residues, moduli), value);
        }
    }

    /// `factor` recovers a random square-free product of entries of `PRIMES`.
    #[test]
    fn factoring() {
        let mut rng = thread_rng();
        for _ in 0..16 {
            let mut chosen = Vec::new();
            let mut composite: u128 = 1;
            for &prime in PRIMES.iter() {
                if !rng.gen_bool() {
                    continue;
                }
                let Some(next) = composite.checked_mul(prime as u128) else {
                    break;
                };
                composite = next;
                chosen.push(prime);
            }
            assert_eq!(factor(composite), chosen);
        }
    }

    /// Splitting a random `u128` into 128 bits and packing it back is lossless.
    #[test]
    fn bits() {
        let mut rng = thread_rng();
        for _ in 0..128 {
            let value = rng.gen_u128();
            let bit_vec = u128_to_bits(value, 128);
            assert_eq!(u128_from_bits(&bit_vec), value);
        }
    }

    /// A usable block for a random modulus survives a base-q round trip.
    #[test]
    fn base_q_conversion() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let value = u128::from(rng.gen_usable_block(q));
            let digits = as_base_q(value, q, digits_per_u128(q));
            assert_eq!(from_base_q(&digits, q), value);
        }
    }
}