1use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5 FancyArithmetic, FancyBinary,
6 errors::FancyError,
7 fancy::bundle::{Bundle, BundleGadgets},
8 util,
9};
10use itertools::Itertools;
11use std::ops::Deref;
12
13#[derive(Clone)]
15pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18 pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20 CrtBundle(Bundle::new(ws))
21 }
22
23 pub fn extract(self) -> Bundle<W> {
25 self.0
26 }
27
28 pub fn composite_modulus(&self) -> u128 {
30 util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31 }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35 type Target = Bundle<W>;
36
37 fn deref(&self) -> &Bundle<W> {
38 &self.0
39 }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43 fn from(b: Bundle<W>) -> CrtBundle<W> {
44 CrtBundle(b)
45 }
46}
47
48impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50pub trait CrtGadgets:
52 FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54 fn crt_constant_bundle(
57 &mut self,
58 x: u128,
59 q: u128,
60 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
61 let ps = util::factor(q);
62 let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
63 self.constant_bundle(&xs, &ps).map(CrtBundle)
64 }
65
66 fn crt_output(&mut self, x: &CrtBundle<Self::Item>) -> Result<Option<u128>, Self::Error> {
68 let q = x.composite_modulus();
69 Ok(self
70 .output_bundle(x)?
71 .map(|xs| util::crt_inv_factor(&xs, q)))
72 }
73
74 fn crt_outputs(
76 &mut self,
77 xs: &[CrtBundle<Self::Item>],
78 ) -> Result<Option<Vec<u128>>, Self::Error> {
79 let mut zs = Vec::with_capacity(xs.len());
80 for x in xs.iter() {
81 let z = self.crt_output(x)?;
82 zs.push(z);
83 }
84 Ok(zs.into_iter().collect())
85 }
86
87 fn crt_add(
92 &mut self,
93 x: &CrtBundle<Self::Item>,
94 y: &CrtBundle<Self::Item>,
95 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
96 self.add_bundles(x, y).map(CrtBundle)
97 }
98
99 fn crt_sub(
101 &mut self,
102 x: &CrtBundle<Self::Item>,
103 y: &CrtBundle<Self::Item>,
104 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
105 self.sub_bundles(x, y).map(CrtBundle)
106 }
107
108 fn crt_cmul(
110 &mut self,
111 x: &CrtBundle<Self::Item>,
112 c: u128,
113 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
114 let cs = util::crt(c, &x.moduli());
115 x.wires()
116 .iter()
117 .zip(cs.into_iter())
118 .map(|(x, c)| self.cmul(x, c))
119 .collect::<Result<Vec<Self::Item>, Self::Error>>()
120 .map(CrtBundle::new)
121 }
122
123 fn crt_mul(
125 &mut self,
126 x: &CrtBundle<Self::Item>,
127 y: &CrtBundle<Self::Item>,
128 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
129 self.mul_bundles(x, y).map(CrtBundle)
130 }
131
132 fn crt_cexp(
134 &mut self,
135 x: &CrtBundle<Self::Item>,
136 c: u16,
137 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
138 x.wires()
139 .iter()
140 .map(|x| {
141 let p = x.modulus();
142 let tab = (0..p)
143 .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
144 .collect_vec();
145 self.proj(x, p, Some(tab))
146 })
147 .collect::<Result<Vec<Self::Item>, Self::Error>>()
148 .map(CrtBundle::new)
149 }
150
151 fn crt_rem(
153 &mut self,
154 x: &CrtBundle<Self::Item>,
155 p: u16,
156 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
157 let i = x.moduli().iter().position(|&q| p == q).ok_or_else(|| {
158 Self::Error::from(FancyError::InvalidArg(
159 "p is not a modulus in this bundle!".to_string(),
160 ))
161 })?;
162 let w = &x.wires()[i];
163 x.moduli()
164 .iter()
165 .map(|&q| self.mod_change(w, q))
166 .collect::<Result<Vec<Self::Item>, Self::Error>>()
167 .map(CrtBundle::new)
168 }
169
170 fn crt_fractional_mixed_radix(
176 &mut self,
177 bun: &CrtBundle<Self::Item>,
178 ms: &[u16],
179 ) -> Result<Self::Item, Self::Error> {
180 let ndigits = ms.len();
181
182 let q = util::product(&bun.moduli());
183 let M = util::product(ms);
184
185 let mut ds = Vec::new();
186
187 for wire in bun.wires().iter() {
188 let p = wire.modulus();
189
190 let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
191
192 for x in 0..p {
193 let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
194 let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
195 let digits = util::as_mixed_radix(y, ms);
196 for i in 0..ndigits {
197 tabs[i].push(digits[i]);
198 }
199 }
200
201 let new_ds = tabs
202 .into_iter()
203 .enumerate()
204 .map(|(i, tt)| self.proj(wire, ms[i], Some(tt)))
205 .collect::<Result<Vec<Self::Item>, Self::Error>>()?;
206
207 ds.push(Bundle::new(new_ds));
208 }
209
210 self.mixed_radix_addition_msb_only(&ds)
211 }
212
213 fn crt_relu(
217 &mut self,
218 x: &CrtBundle<Self::Item>,
219 accuracy: &str,
220 output_moduli: Option<&[u16]>,
221 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
222 let factors_of_m = &get_ms(x, accuracy);
223 let res = self.crt_fractional_mixed_radix(x, factors_of_m)?;
224
225 let p = *factors_of_m.last().unwrap();
227 let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
228 let mask = self.proj(&res, 2, Some(mask_tt))?;
229
230 output_moduli
232 .map(|ps| x.with_moduli(ps))
233 .as_ref()
234 .unwrap_or(x)
235 .wires()
236 .iter()
237 .map(|x| self.mul(x, &mask))
238 .collect::<Result<Vec<Self::Item>, Self::Error>>()
239 .map(CrtBundle::new)
240 }
241
242 fn crt_sign(
244 &mut self,
245 x: &CrtBundle<Self::Item>,
246 accuracy: &str,
247 ) -> Result<Self::Item, Self::Error> {
248 let factors_of_m = &get_ms(x, accuracy);
249 let res = self.crt_fractional_mixed_radix(x, factors_of_m)?;
250 let p = *factors_of_m.last().unwrap();
251 let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
252 self.proj(&res, 2, Some(tt))
253 }
254
255 fn crt_sgn(
259 &mut self,
260 x: &CrtBundle<Self::Item>,
261 accuracy: &str,
262 output_moduli: Option<&[u16]>,
263 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
264 let sign = self.crt_sign(x, accuracy)?;
265 output_moduli
266 .unwrap_or(&x.moduli())
267 .iter()
268 .map(|&p| {
269 let tt = vec![1, p - 1];
270 self.proj(&sign, p, Some(tt))
271 })
272 .collect::<Result<Vec<Self::Item>, Self::Error>>()
273 .map(CrtBundle::new)
274 }
275
276 fn crt_lt(
278 &mut self,
279 x: &CrtBundle<Self::Item>,
280 y: &CrtBundle<Self::Item>,
281 accuracy: &str,
282 ) -> Result<Self::Item, Self::Error> {
283 let z = self.crt_sub(x, y)?;
284 self.crt_sign(&z, accuracy)
285 }
286
287 fn crt_geq(
289 &mut self,
290 x: &CrtBundle<Self::Item>,
291 y: &CrtBundle<Self::Item>,
292 accuracy: &str,
293 ) -> Result<Self::Item, Self::Error> {
294 let z = self.crt_lt(x, y, accuracy)?;
295 self.negate(&z)
296 }
297
298 fn crt_max(
300 &mut self,
301 xs: &[CrtBundle<Self::Item>],
302 accuracy: &str,
303 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
304 if xs.is_empty() {
305 return Err(Self::Error::from(FancyError::InvalidArgNum {
306 got: xs.len(),
307 needed: 1,
308 }));
309 }
310 xs.iter().skip(1).fold(Ok(xs[0].clone()), |x, y| {
311 x.map(|x| {
312 let pos = self.crt_lt(&x, y, accuracy)?;
313 let neg = self.negate(&pos)?;
314 x.wires()
315 .iter()
316 .zip(y.wires().iter())
317 .map(|(x, y)| {
318 let xp = self.mul(x, &neg)?;
319 let yp = self.mul(y, &pos)?;
320 self.add(&xp, &yp)
321 })
322 .collect::<Result<Vec<Self::Item>, Self::Error>>()
323 .map(CrtBundle::new)
324 })?
325 })
326 }
327
328 fn crt_to_pmr(
330 &mut self,
331 xs: &CrtBundle<Self::Item>,
332 ) -> Result<Bundle<Self::Item>, Self::Error> {
333 let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
334 let pq = p as u32 + q as u32 - 1;
335 let mut tab = Vec::with_capacity(pq as usize);
336 for z in 0..pq {
337 let mut x = 0;
338 let mut y = 0;
339 'outer: for i in 0..p as u32 {
340 for j in 0..q as u32 {
341 if (i + pq - j) % pq == z {
342 x = i;
343 y = j;
344 break 'outer;
345 }
346 }
347 }
348 debug_assert_eq!((x + pq - y) % pq, z);
349 tab.push(
350 (((x * q as u32 * util::inv(q as i128, p as i128) as u32
351 + y * p as u32 * util::inv(p as i128, q as i128) as u32)
352 / p as u32)
353 % q as u32) as u16,
354 );
355 }
356 tab
357 };
358
359 let mut gadget = |x: &Self::Item, y: &Self::Item| -> Result<Self::Item, Self::Error> {
360 let p = x.modulus();
361 let q = y.modulus();
362 let x_ = self.mod_change(x, p + q - 1)?;
363 let y_ = self.mod_change(y, p + q - 1)?;
364 let z = self.sub(&x_, &y_)?;
365 self.proj(&z, q, Some(gadget_projection_tt(p, q)))
366 };
367
368 let n = xs.size();
369 let mut x = vec![vec![None; n + 1]; n + 1];
370
371 for j in 0..n {
372 x[0][j + 1] = Some(xs.wires()[j].clone());
373 }
374
375 for i in 1..=n {
376 for j in i + 1..=n {
377 let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
378 x[i][j] = Some(z);
379 }
380 }
381
382 let mut zwires = Vec::with_capacity(n);
383 for i in 0..n {
384 zwires.push(x[i][i + 1].take().unwrap());
385 }
386 Ok(Bundle::new(zwires))
387 }
388
389 fn pmr_lt(
395 &mut self,
396 x: &CrtBundle<Self::Item>,
397 y: &CrtBundle<Self::Item>,
398 ) -> Result<Self::Item, Self::Error> {
399 let z = self.crt_sub(x, y)?;
400 let mut pmr = self.crt_to_pmr(&z)?;
401 let w = pmr.pop().unwrap();
402 let mut tab = vec![1; w.modulus() as usize];
403 tab[0] = 0;
404 self.proj(&w, 2, Some(tab))
405 }
406
407 fn pmr_geq(
413 &mut self,
414 x: &CrtBundle<Self::Item>,
415 y: &CrtBundle<Self::Item>,
416 ) -> Result<Self::Item, Self::Error> {
417 let z = self.pmr_lt(x, y)?;
418 self.negate(&z)
419 }
420
421 fn crt_div(
424 &mut self,
425 x: &CrtBundle<Self::Item>,
426 y: &CrtBundle<Self::Item>,
427 ) -> Result<CrtBundle<Self::Item>, Self::Error> {
428 if x.moduli() != y.moduli() {
429 return Err(Self::Error::from(FancyError::UnequalModuli));
430 }
431
432 let q = x.composite_modulus();
433
434 let nprimes = x.moduli().len();
436 let qs_ = &x.moduli()[..nprimes - 1];
437 let q_ = util::product(qs_);
438 let l = 128 - q_.leading_zeros();
439
440 let mut quotient = self.crt_constant_bundle(0, q)?;
441 let mut a = x.clone();
442
443 let one = self.crt_constant_bundle(1, q)?;
444 for i in 0..l {
445 let b = 2u128.pow(l - i - 1);
446 let mut pb = q_ / b;
447 if q_ % b == 0 {
448 pb -= 1;
449 }
450
451 let tmp = self.crt_cmul(y, b)?;
452 let c1 = self.pmr_geq(&a, &tmp)?;
453
454 let pb_crt = self.crt_constant_bundle(pb, q)?;
455 let c2 = self.pmr_geq(&pb_crt, y)?;
456
457 let c = self.and(&c1, &c2)?;
458
459 let c_ws = one
460 .iter()
461 .map(|w| self.mul(w, &c))
462 .collect::<Result<Vec<_>, _>>()?;
463 let c_crt = CrtBundle::new(c_ws);
464
465 let b_if = self.crt_cmul(&c_crt, b)?;
466 quotient = self.crt_add("ient, &b_if)?;
467
468 let tmp_if = self.crt_mul(&c_crt, &tmp)?;
469 a = self.crt_sub(&a, &tmp_if)?;
470 }
471
472 Ok(quotient)
473 }
474}
475
476fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
481 match accuracy {
482 "100%" => match x.moduli().len() {
483 3 => vec![2; 5],
484 4 => vec![3, 26],
485 5 => vec![3, 4, 54],
486 6 => vec![5, 5, 5, 60],
487 7 => vec![5, 6, 6, 7, 86],
488 8 => vec![5, 7, 8, 8, 9, 98],
489 9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
490 10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
491 11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
492 n => panic!("unknown exact Ms for {} primes!", n),
493 },
494 "99.999%" => match x.moduli().len() {
495 8 => vec![5, 5, 6, 7, 102],
496 9 => vec![5, 5, 6, 7, 114],
497 10 => vec![5, 6, 6, 7, 102],
498 11 => vec![5, 5, 6, 7, 130],
499 n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
500 },
501 "99.99%" => match x.moduli().len() {
502 6 => vec![5, 5, 5, 42],
503 7 => vec![4, 5, 6, 88],
504 8 => vec![4, 5, 7, 78],
505 9 => vec![5, 5, 6, 84],
506 10 => vec![4, 5, 6, 112],
507 11 => vec![7, 11, 174],
508 n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
509 },
510 "99.9%" => match x.moduli().len() {
511 5 => vec![3, 5, 30],
512 6 => vec![4, 5, 48],
513 7 => vec![4, 5, 60],
514 8 => vec![3, 5, 78],
515 9 => vec![9, 140],
516 10 => vec![7, 190],
517 n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
518 },
519 "99%" => match x.moduli().len() {
520 4 => vec![3, 18],
521 5 => vec![3, 36],
522 6 => vec![3, 40],
523 7 => vec![3, 40],
524 8 => vec![126],
525 9 => vec![138],
526 10 => vec![140],
527 n => panic!("unknown 99% accurate Ms for {} primes!", n),
528 },
529 _ => panic!("get_ms: unsupported accuracy {}", accuracy),
530 }
531}