1use itertools::Itertools;
6use std::collections::HashMap;
7use swanky_block::Block;
8use vectoreyes::{SimdBase, U8x16, U64x2};
9
10use crate::WireLabel;
11
12pub fn tweak(i: usize) -> Block {
17 Block::from(U8x16::from(U64x2::set_lo(i as u64)))
18}
19
20pub fn tweak2(i: u64, j: u64) -> Block {
22 Block::from(U8x16::from(U64x2::from([j, i])))
23}
24
25pub fn output_tweak(i: usize, k: u16) -> Block {
27 let (left, _) = (i as u128).overflowing_shl(64);
28 Block::from(left + k as u128)
29}
30
/// Adds the base-`q` number `ys` into `xs` in place.
///
/// Both are little-endian digit vectors (index 0 is least significant). The
/// carry ripples into the digits of `xs` beyond `ys.len()`; a carry out of the
/// most significant digit of `xs` is discarded, i.e. the result is taken
/// modulo `q^xs.len()`.
///
/// # Panics
/// Panics if `ys` is longer than `xs` (with a detailed message in debug builds).
pub fn base_q_add_eq(xs: &mut [u16], ys: &[u16], q: u16) {
    debug_assert!(
        xs.len() >= ys.len(),
        "q={} xs.len()={} ys.len()={} xs={:?} ys={:?}",
        q,
        xs.len(),
        ys.len(),
        xs,
        ys
    );

    // Splitting first keeps a too-short `xs` a panic in release builds as well.
    let (low, high) = xs.split_at_mut(ys.len());

    let q_wide = u32::from(q);
    let mut c: u16 = 0;
    for (x, &y) in low.iter_mut().zip(ys) {
        // Widen to u32: with q > 32768, two digits plus a carry exceed u16::MAX.
        let sum = u32::from(*x) + u32::from(y) + u32::from(c);
        c = u16::from(sum >= q_wide);
        *x = (sum - u32::from(c) * q_wide) as u16;
    }

    // Ripple the final carry through the remaining digits until it is absorbed.
    for x in high {
        *x += c;
        if *x >= q {
            *x -= q;
        } else {
            break;
        }
    }
}
69
70fn as_base_q(x: u128, q: u16, n: usize) -> Vec<u16> {
72 let ms = std::iter::repeat(q).take(n).collect_vec();
73 as_mixed_radix(x, &ms)
74}
75
/// Returns how many base-`modulus` digits are used to represent a `u128`.
///
/// The result is `floor(128 / ceil(log2(modulus)))`. Each digit is budgeted
/// `ceil(log2(modulus))` bits, so `modulus^n <= 2^128` for the returned `n`.
///
/// `modulus` must be at least 2, which is checked in debug builds. In release
/// builds, 0 and 1 return 64.
pub fn digits_per_u128(modulus: u16) -> usize {
    debug_assert_ne!(modulus, 0);
    debug_assert_ne!(modulus, 1);
    if modulus < 2 {
        return 64;
    }
    // For m >= 2, ceil(log2(m)) is exactly the bit length of m - 1. Computing it
    // with integers avoids f64::log2, whose precision is platform-dependent.
    let bits_per_digit = u16::BITS - (modulus - 1).leading_zeros();
    (128 / bits_per_digit) as usize
}
103
104pub fn as_base_q_u128(x: u128, q: u16) -> Vec<u16> {
106 as_base_q(x, q, digits_per_u128(q))
107}
108
/// Decomposes `x` into little-endian mixed-radix digits.
///
/// Digit `i` lies in `0..radii[i]`. One digit is produced per radix, and any
/// part of `x` that does not fit in the given radii is dropped.
///
/// # Panics
/// Panics (division by zero) if a radix is 0.
pub fn as_mixed_radix(x: u128, radii: &[u16]) -> Vec<u16> {
    let mut remaining = x;
    radii
        .iter()
        .map(|&radix| {
            let r = u128::from(radix);
            let digit = remaining % r;
            remaining /= r;
            digit as u16
        })
        .collect()
}
127
/// Interprets `ds` as a little-endian base-`q` number and returns its value.
///
/// Overflow of the intermediate multiply is caught only in debug builds.
pub fn from_base_q(ds: &[u16], q: u16) -> u128 {
    ds.iter().rev().fold(0u128, |acc, &digit| {
        let (scaled, overflowed) = acc.overflowing_mul(u128::from(q));
        debug_assert!(!overflowed, "overflow!!!! x={}", acc);
        scaled + u128::from(digit)
    })
}
138
/// Interprets `digits` as a little-endian mixed-radix number with radii `radii`.
///
/// Digits and radii are paired positionally. Overflow of the intermediate
/// multiply is caught only in debug builds.
pub fn from_mixed_radix(digits: &[u16], radii: &[u16]) -> u128 {
    digits
        .iter()
        .zip(radii)
        .rev()
        .fold(0u128, |acc, (&digit, &radix)| {
            let (scaled, overflowed) = acc.overflowing_mul(u128::from(radix));
            debug_assert!(!overflowed, "overflow!!!! x={}", acc);
            scaled + u128::from(digit)
        })
}
149
/// Returns the `n` least-significant bits of `x`, least-significant first.
///
/// Positions at or beyond bit 128 are 0.
pub fn u128_to_bits(x: u128, n: usize) -> Vec<u16> {
    let mut rest = x;
    (0..n)
        .map(|_| {
            let bit = (rest & 1) as u16;
            rest >>= 1;
            bit
        })
        .collect()
}
166
/// Reassembles a `u128` from little-endian bits (index 0 is least significant).
///
/// An empty slice yields 0. Each entry is weighted by its power of two as-is,
/// so entries are expected to be 0 or 1. Overflow past 128 bits follows normal
/// `u128` arithmetic (it panics in debug builds).
pub fn u128_from_bits(bs: &[u16]) -> u128 {
    // Horner evaluation from the most significant bit down. Unlike indexing
    // `bs[0]`, this has no panic path for an empty slice.
    bs.iter()
        .rev()
        .fold(0u128, |acc, &b| acc * 2 + u128::from(b))
}
177
178pub fn factor(inp: u128) -> Vec<u16> {
187 let mut x = inp;
188 let mut fs = Vec::new();
189 for &p in PRIMES.iter() {
190 let q = p as u128;
191 if x % q == 0 {
192 fs.push(p);
193 x /= q;
194 }
195 }
196 if x != 1 {
197 panic!("can only factor numbers with unique prime factors");
198 }
199 fs
200}
201
/// Returns the residues of `x` modulo each entry of `ps`, in order.
pub fn crt(x: u128, ps: &[u16]) -> Vec<u16> {
    ps.iter()
        .copied()
        .map(|p| (x % u128::from(p)) as u16)
        .collect()
}
206
207pub fn crt_factor(x: u128, q: u128) -> Vec<u16> {
210 crt(x, &factor(q))
211}
212
213pub fn crt_inv(xs: &[u16], ps: &[u16]) -> u128 {
215 let mut ret = 0;
216 let M = ps.iter().fold(1, |acc, &x| x as i128 * acc);
217 for (&p, &a) in ps.iter().zip(xs.iter()) {
218 let p = p as i128;
219 let q = M / p;
220 ret += a as i128 * inv(q, p) * q;
221 ret %= &M;
222 }
223 ret as u128
224}
225
226pub fn crt_inv_factor(xs: &[u16], q: u128) -> u128 {
228 crt_inv(xs, &factor(q))
229}
230
/// Modular inverse of `inp_a` modulo `inp_b`, via the extended Euclidean algorithm.
///
/// Returns 1 when `inp_b == 1`. A negative Bézout coefficient is shifted up by
/// `inp_b` once. Assumes `gcd(inp_a, inp_b) == 1`; otherwise the result is not
/// an inverse, and the loop may divide by zero.
pub fn inv(inp_a: i128, inp_b: i128) -> i128 {
    if inp_b == 1 {
        return 1;
    }

    let (mut a, mut b) = (inp_a, inp_b);
    let (mut x0, mut x1) = (0i128, 1i128);
    while a > 1 {
        let quotient = a / b;
        (a, b) = (b, a % b);
        (x0, x1) = (x1 - quotient * x0, x0);
    }

    if x1 < 0 {
        x1 + inp_b
    } else {
        x1
    }
}
263
264pub const NPRIMES: usize = 29;
266
/// The first 29 primes in increasing order, from 2 through 109.
pub const PRIMES: [u16; 29] = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, // below 50
    53, 59, 61, 67, 71, 73, 79, 83, 89, 97, // 50..100
    101, 103, 107, 109, // just above 100
];
272
273pub fn modulus_with_nprimes(n: usize) -> u128 {
281 product(&PRIMES[0..n])
282}
283
284pub fn modulus_with_width(n: u32) -> u128 {
287 base_modulus_with_width(n, &PRIMES)
288}
289
290pub fn primes_with_width(n: u32) -> Vec<u16> {
293 base_primes_with_width(n, &PRIMES)
294}
295
296pub fn base_modulus_with_width(nbits: u32, primes: &[u16]) -> u128 {
298 product(&base_primes_with_width(nbits, primes))
299}
300
/// Returns the shortest prefix of `primes` whose product is at least `2^nbits`.
///
/// # Panics
/// - Panics with "not enough primes!" if the full `primes` list does not reach
///   `2^nbits`. This includes every `nbits >= 128`, which no `u128` product can reach.
/// - Panics if the running product overflows `u128`.
pub fn base_primes_with_width(nbits: u32, primes: &[u16]) -> Vec<u16> {
    // `res >> nbits > 0` is exactly `res >= 2^nbits`. `checked_shr` returns None
    // for nbits >= 128. A plain `>>` there would panic in debug builds, and in
    // release builds would mask the shift amount and stop after the first prime.
    let reaches_width = |res: u128| matches!(res.checked_shr(nbits), Some(hi) if hi > 0);

    let mut res: u128 = 1;
    let mut ps = Vec::new();
    for &p in primes {
        res = res
            .checked_mul(u128::from(p))
            .expect("product of primes overflows u128");
        ps.push(p);
        if reaches_width(res) {
            break;
        }
    }
    assert!(reaches_width(res), "not enough primes!");
    ps
}
315
/// Product of `xs` as a `u128`; the empty product is 1.
pub fn product(xs: &[u16]) -> u128 {
    xs.iter().map(|&x| u128::from(x)).product()
}
326
/// Returns `true` if `x` is a power of two (1, 2, 4, …, 32768).
///
/// Zero is not a power of two and returns `false`. The standard-library check
/// avoids the `x - 1` underflow that the `x & (x - 1)` trick hits at 0.
pub fn is_power_of_2(x: u16) -> bool {
    x.is_power_of_two()
}
348
349pub fn generate_deltas<Wire: WireLabel>(primes: &[u16]) -> HashMap<u16, Wire> {
351 let mut deltas = HashMap::new();
352 let mut rng = rand::thread_rng();
353 for q in primes {
354 deltas.insert(*q, Wire::rand_delta(&mut rng, *q));
355 }
356 deltas
357}
358
359pub trait RngExt: rand::Rng + Sized {
361 fn gen_bool(&mut self) -> bool {
363 self.r#gen()
364 }
365 fn gen_u16(&mut self) -> u16 {
367 self.r#gen()
368 }
369 fn gen_u32(&mut self) -> u32 {
371 self.r#gen()
372 }
373 fn gen_u64(&mut self) -> u64 {
375 self.r#gen()
376 }
377 fn gen_usize(&mut self) -> usize {
379 self.r#gen()
380 }
381 fn gen_u128(&mut self) -> u128 {
383 self.r#gen()
384 }
385 fn gen_block(&mut self) -> Block {
387 self.r#gen()
388 }
389 fn gen_usable_block(&mut self, modulus: u16) -> Block {
391 if is_power_of_2(modulus) {
392 let nbits = (modulus - 1).count_ones();
393 if 128 % nbits == 0 {
394 return Block::from(self.gen_u128());
395 }
396 }
397 let n = digits_per_u128(modulus);
398 let max = (modulus as u128).pow(n as u32);
399 Block::from(self.gen_u128() % max)
400 }
401 fn gen_prime(&mut self) -> u16 {
403 PRIMES[self.r#gen::<usize>() % NPRIMES]
404 }
405 fn gen_modulus(&mut self) -> u16 {
407 2 + (self.r#gen::<u16>() % 111)
408 }
409 fn gen_usable_composite_modulus(&mut self) -> u128 {
411 product(&self.gen_usable_factors())
412 }
413 fn gen_usable_factors(&mut self) -> Vec<u16> {
415 let mut x: u128 = 1;
416 PRIMES[..25]
417 .iter()
418 .cloned()
419 .filter(|_| self.r#gen()) .take_while(|&q| {
421 match x.checked_mul(q as u128) {
423 None => false,
424 Some(y) => {
425 x = y;
426 true
427 }
428 }
429 })
430 .collect()
431 }
432}
433
434impl<R: rand::Rng + Sized> RngExt for R {}
435
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// `crt_inv` undoes `crt` for random values below the product of the first 25 primes.
    #[test]
    fn crt_conversion() {
        let mut rng = thread_rng();
        let primes = &PRIMES[..25];
        let modulus = product(primes);

        for _ in 0..128 {
            let value = rng.gen_u128() % modulus;
            let residues = crt(value, primes);
            assert_eq!(crt_inv(&residues, primes), value);
        }
    }

    /// `factor` recovers a random subset of `PRIMES` from its product.
    #[test]
    fn factoring() {
        let mut rng = thread_rng();
        for _ in 0..16 {
            let mut chosen = Vec::new();
            let mut modulus: u128 = 1;
            for &p in &PRIMES {
                if !rng.gen_bool() {
                    continue;
                }
                match modulus.checked_mul(u128::from(p)) {
                    Some(next) => modulus = next,
                    None => break,
                }
                chosen.push(p);
            }
            assert_eq!(factor(modulus), chosen);
        }
    }

    /// `u128_to_bits` followed by `u128_from_bits` is the identity on 128 bits.
    #[test]
    fn bits() {
        let mut rng = thread_rng();
        for _ in 0..128 {
            let value = rng.gen_u128();
            let decoded = u128_from_bits(&u128_to_bits(value, 128));
            assert_eq!(decoded, value);
        }
    }

    /// Base-q decomposition round-trips for usable blocks of random small moduli.
    #[test]
    fn base_q_conversion() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let value = u128::from(rng.gen_usable_block(q));
            let digits = as_base_q(value, q, digits_per_u128(q));
            assert_eq!(from_base_q(&digits, q), value);
        }
    }
}