1use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5 FancyArithmetic, FancyBinary,
6 fancy::bundle::{ArithmeticProjBundleGadgets, Bundle, BundleGadgets},
7 util,
8};
9use itertools::Itertools;
10use std::ops::Deref;
11use swanky_channel::Channel;
12
/// A [`Bundle`] of wires that together represent one integer in CRT
/// (Chinese Remainder Theorem) form: each wire holds the value reduced modulo
/// that wire's own modulus, and the represented value lives modulo the product
/// of all the wire moduli (see [`CrtBundle::composite_modulus`]).
#[derive(Clone)]
pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18 pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20 CrtBundle(Bundle::new(ws))
21 }
22
23 pub fn extract(self) -> Bundle<W> {
25 self.0
26 }
27
28 pub fn composite_modulus(&self) -> u128 {
30 util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31 }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35 type Target = Bundle<W>;
36
37 fn deref(&self) -> &Bundle<W> {
38 &self.0
39 }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43 fn from(b: Bundle<W>) -> CrtBundle<W> {
44 CrtBundle(b)
45 }
46}
47
// Blanket impl: every type implementing both `FancyArithmetic` and `FancyBinary`
// gets the CRT gadgets through the trait's default methods. `CrtGadgets` also
// requires `ArithmeticBundleGadgets + BundleGadgets`; presumably those are
// blanket-implemented for the same types elsewhere — confirm.
impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50pub trait CrtGadgets:
52 FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54 fn crt_encode(
56 &mut self,
57 value: u128,
58 modulus: u128,
59 channel: &mut Channel,
60 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
61 let qs = util::factor(modulus);
62 let xs = util::crt(value, &qs);
63 self.encode_bundle(&xs, &qs, channel).map(CrtBundle::from)
64 }
65
66 fn crt_receive(
68 &mut self,
69 modulus: u128,
70 channel: &mut Channel,
71 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
72 let qs = util::factor(modulus);
73 self.receive_bundle(&qs, channel).map(CrtBundle::from)
74 }
75
76 fn crt_encode_many(
78 &mut self,
79 values: &[u128],
80 modulus: u128,
81 channel: &mut Channel,
82 ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
83 let mods = util::factor(modulus);
84 let nmods = mods.len();
85 let xs = values
86 .iter()
87 .flat_map(|x| util::crt(*x, &mods))
88 .collect_vec();
89 let qs = itertools::repeat_n(mods, values.len())
90 .flatten()
91 .collect_vec();
92 let mut wires = self.encode_many(&xs, &qs, channel)?;
93 let buns = (0..values.len())
94 .map(|_| {
95 let ws = wires.drain(0..nmods).collect_vec();
96 CrtBundle::new(ws)
97 })
98 .collect_vec();
99 Ok(buns)
100 }
101
102 fn crt_receive_many(
104 &mut self,
105 n: usize,
106 modulus: u128,
107 channel: &mut Channel,
108 ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
109 let mods = util::factor(modulus);
110 let nmods = mods.len();
111 let qs = itertools::repeat_n(mods, n).flatten().collect_vec();
112 let mut wires = self.receive_many(&qs, channel)?;
113 let buns = (0..n)
114 .map(|_| {
115 let ws = wires.drain(0..nmods).collect_vec();
116 CrtBundle::new(ws)
117 })
118 .collect_vec();
119 Ok(buns)
120 }
121
122 fn crt_constant_bundle(
125 &mut self,
126 x: u128,
127 q: u128,
128 channel: &mut Channel,
129 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
130 let ps = util::factor(q);
131 let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
132 self.constant_bundle(&xs, &ps, channel).map(CrtBundle)
133 }
134
135 fn crt_output(
137 &mut self,
138 x: &CrtBundle<Self::Item>,
139 channel: &mut Channel,
140 ) -> swanky_error::Result<Option<u128>> {
141 let q = x.composite_modulus();
142 Ok(self
143 .output_bundle(x, channel)?
144 .map(|xs| util::crt_inv_factor(&xs, q)))
145 }
146
147 fn crt_outputs(
149 &mut self,
150 xs: &[CrtBundle<Self::Item>],
151 channel: &mut Channel,
152 ) -> swanky_error::Result<Option<Vec<u128>>> {
153 let mut zs = Vec::with_capacity(xs.len());
154 for x in xs.iter() {
155 let z = self.crt_output(x, channel)?;
156 zs.push(z);
157 }
158 Ok(zs.into_iter().collect())
159 }
160
161 fn crt_add(
166 &mut self,
167 x: &CrtBundle<Self::Item>,
168 y: &CrtBundle<Self::Item>,
169 ) -> CrtBundle<Self::Item> {
170 CrtBundle(self.add_bundles(x, y))
171 }
172
173 fn crt_sub(
175 &mut self,
176 x: &CrtBundle<Self::Item>,
177 y: &CrtBundle<Self::Item>,
178 ) -> CrtBundle<Self::Item> {
179 CrtBundle(self.sub_bundles(x, y))
180 }
181
182 fn crt_cmul(&mut self, x: &CrtBundle<Self::Item>, c: u128) -> CrtBundle<Self::Item> {
184 let cs = util::crt(c, &x.moduli());
185 CrtBundle::new(
186 x.wires()
187 .iter()
188 .zip(cs)
189 .map(|(x, c)| self.cmul(x, c))
190 .collect::<Vec<Self::Item>>(),
191 )
192 }
193
194 fn crt_mul(
196 &mut self,
197 x: &CrtBundle<Self::Item>,
198 y: &CrtBundle<Self::Item>,
199 channel: &mut Channel,
200 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
201 self.mul_bundles(x, y, channel).map(CrtBundle)
202 }
203}
204
// Blanket impl: every type implementing both `ArithmeticProjBundleGadgets` and
// `CrtGadgets` gets the projection-based CRT gadgets through the trait's
// default methods.
impl<F: ArithmeticProjBundleGadgets + CrtGadgets> CrtProjGadgets for F {}
206
207pub trait CrtProjGadgets: ArithmeticProjBundleGadgets + CrtGadgets {
209 fn crt_cexp(
211 &mut self,
212 x: &CrtBundle<Self::Item>,
213 c: u16,
214 channel: &mut Channel,
215 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
216 x.wires()
217 .iter()
218 .map(|x| {
219 let p = x.modulus();
220 let tab = (0..p)
221 .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
222 .collect_vec();
223 self.proj(x, p, Some(tab), channel)
224 })
225 .collect::<swanky_error::Result<_>>()
226 .map(CrtBundle::new)
227 }
228
229 fn crt_rem(
234 &mut self,
235 x: &CrtBundle<Self::Item>,
236 p: u16,
237 channel: &mut Channel,
238 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
239 let i = x.moduli().iter().position(|&q| p == q);
240 assert!(i.is_some(), "`p` is not a modulus in the `x` bundle");
241 let i = i.unwrap();
242 let w = &x.wires()[i];
243 x.moduli()
244 .iter()
245 .map(|&q| self.mod_change(w, q, channel))
246 .collect::<swanky_error::Result<_>>()
247 .map(CrtBundle::new)
248 }
249
250 fn crt_fractional_mixed_radix(
256 &mut self,
257 bun: &CrtBundle<Self::Item>,
258 ms: &[u16],
259 channel: &mut Channel,
260 ) -> swanky_error::Result<Self::Item> {
261 let ndigits = ms.len();
262
263 let q = util::product(&bun.moduli());
264 let M = util::product(ms);
265
266 let mut ds = Vec::new();
267
268 for wire in bun.wires().iter() {
269 let p = wire.modulus();
270
271 let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
272
273 for x in 0..p {
274 let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
275 let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
276 let digits = util::as_mixed_radix(y, ms);
277 for i in 0..ndigits {
278 tabs[i].push(digits[i]);
279 }
280 }
281
282 let new_ds = tabs
283 .into_iter()
284 .enumerate()
285 .map(|(i, tt)| self.proj(wire, ms[i], Some(tt), channel))
286 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
287
288 ds.push(Bundle::new(new_ds));
289 }
290
291 self.mixed_radix_addition_msb_only(&ds, channel)
292 }
293
294 fn crt_relu(
298 &mut self,
299 x: &CrtBundle<Self::Item>,
300 accuracy: &str,
301 output_moduli: Option<&[u16]>,
302 channel: &mut Channel,
303 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
304 let factors_of_m = &get_ms(x, accuracy);
305 let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
306
307 let p = *factors_of_m.last().unwrap();
309 let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
310 let mask = self.proj(&res, 2, Some(mask_tt), channel)?;
311
312 output_moduli
314 .map(|ps| x.with_moduli(ps))
315 .as_ref()
316 .unwrap_or(x)
317 .wires()
318 .iter()
319 .map(|x| self.mul(x, &mask, channel))
320 .collect::<swanky_error::Result<_>>()
321 .map(CrtBundle::new)
322 }
323
324 fn crt_sign(
326 &mut self,
327 x: &CrtBundle<Self::Item>,
328 accuracy: &str,
329 channel: &mut Channel,
330 ) -> swanky_error::Result<Self::Item> {
331 let factors_of_m = &get_ms(x, accuracy);
332 let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
333 let p = *factors_of_m.last().unwrap();
334 let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
335 self.proj(&res, 2, Some(tt), channel)
336 }
337
338 fn crt_sgn(
342 &mut self,
343 x: &CrtBundle<Self::Item>,
344 accuracy: &str,
345 output_moduli: Option<&[u16]>,
346 channel: &mut Channel,
347 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
348 let sign = self.crt_sign(x, accuracy, channel)?;
349 output_moduli
350 .unwrap_or(&x.moduli())
351 .iter()
352 .map(|&p| {
353 let tt = vec![1, p - 1];
354 self.proj(&sign, p, Some(tt), channel)
355 })
356 .collect::<swanky_error::Result<_>>()
357 .map(CrtBundle::new)
358 }
359
360 fn crt_lt(
362 &mut self,
363 x: &CrtBundle<Self::Item>,
364 y: &CrtBundle<Self::Item>,
365 accuracy: &str,
366 channel: &mut Channel,
367 ) -> swanky_error::Result<Self::Item> {
368 let z = self.crt_sub(x, y);
369 self.crt_sign(&z, accuracy, channel)
370 }
371
372 fn crt_geq(
374 &mut self,
375 x: &CrtBundle<Self::Item>,
376 y: &CrtBundle<Self::Item>,
377 accuracy: &str,
378 channel: &mut Channel,
379 ) -> swanky_error::Result<Self::Item> {
380 let z = self.crt_lt(x, y, accuracy, channel)?;
381 Ok(self.negate(&z))
382 }
383
384 fn crt_max(
389 &mut self,
390 xs: &[CrtBundle<Self::Item>],
391 accuracy: &str,
392 channel: &mut Channel,
393 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
394 assert!(!xs.is_empty(), "`xs` cannot be empty");
395 xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
396 let pos = self.crt_lt(&x, y, accuracy, channel)?;
397 let neg = self.negate(&pos);
398 Ok(CrtBundle::new(
399 x.wires()
400 .iter()
401 .zip(y.wires().iter())
402 .map(|(x, y)| {
403 let xp = self.mul(x, &neg, channel)?;
404 let yp = self.mul(y, &pos, channel)?;
405 Ok(self.add(&xp, &yp))
406 })
407 .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
408 ))
409 })
410 }
411
412 fn crt_to_pmr(
414 &mut self,
415 xs: &CrtBundle<Self::Item>,
416 channel: &mut Channel,
417 ) -> swanky_error::Result<Bundle<Self::Item>> {
418 let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
419 let pq = p as u32 + q as u32 - 1;
420 let mut tab = Vec::with_capacity(pq as usize);
421 for z in 0..pq {
422 let mut x = 0;
423 let mut y = 0;
424 'outer: for i in 0..p as u32 {
425 for j in 0..q as u32 {
426 if (i + pq - j) % pq == z {
427 x = i;
428 y = j;
429 break 'outer;
430 }
431 }
432 }
433 debug_assert_eq!((x + pq - y) % pq, z);
434 tab.push(
435 (((x * q as u32 * util::inv(q as i128, p as i128) as u32
436 + y * p as u32 * util::inv(p as i128, q as i128) as u32)
437 / p as u32)
438 % q as u32) as u16,
439 );
440 }
441 tab
442 };
443
444 let mut gadget = |x: &Self::Item, y: &Self::Item| -> swanky_error::Result<Self::Item> {
445 let p = x.modulus();
446 let q = y.modulus();
447 let x_ = self.mod_change(x, p + q - 1, channel)?;
448 let y_ = self.mod_change(y, p + q - 1, channel)?;
449 let z = self.sub(&x_, &y_);
450 self.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
451 };
452
453 let n = xs.size();
454 let mut x = vec![vec![None; n + 1]; n + 1];
455
456 for j in 0..n {
457 x[0][j + 1] = Some(xs.wires()[j].clone());
458 }
459
460 for i in 1..=n {
461 for j in i + 1..=n {
462 let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
463 x[i][j] = Some(z);
464 }
465 }
466
467 let mut zwires = Vec::with_capacity(n);
468 for i in 0..n {
469 zwires.push(x[i][i + 1].take().unwrap());
470 }
471 Ok(Bundle::new(zwires))
472 }
473
474 fn pmr_lt(
480 &mut self,
481 x: &CrtBundle<Self::Item>,
482 y: &CrtBundle<Self::Item>,
483 channel: &mut Channel,
484 ) -> swanky_error::Result<Self::Item> {
485 let z = self.crt_sub(x, y);
486 let mut pmr = self.crt_to_pmr(&z, channel)?;
487 let w = pmr.pop().unwrap();
488 let mut tab = vec![1; w.modulus() as usize];
489 tab[0] = 0;
490 self.proj(&w, 2, Some(tab), channel)
491 }
492
493 fn pmr_geq(
499 &mut self,
500 x: &CrtBundle<Self::Item>,
501 y: &CrtBundle<Self::Item>,
502 channel: &mut Channel,
503 ) -> swanky_error::Result<Self::Item> {
504 let z = self.pmr_lt(x, y, channel)?;
505 Ok(self.negate(&z))
506 }
507
508 fn crt_div(
514 &mut self,
515 x: &CrtBundle<Self::Item>,
516 y: &CrtBundle<Self::Item>,
517 channel: &mut Channel,
518 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
519 assert_eq!(x.moduli(), y.moduli());
520
521 let q = x.composite_modulus();
522
523 let nprimes = x.moduli().len();
525 let qs_ = &x.moduli()[..nprimes - 1];
526 let q_ = util::product(qs_);
527 let l = 128 - q_.leading_zeros();
528
529 let mut quotient = self.crt_constant_bundle(0, q, channel)?;
530 let mut a = x.clone();
531
532 let one = self.crt_constant_bundle(1, q, channel)?;
533 for i in 0..l {
534 let b = 2u128.pow(l - i - 1);
535 let mut pb = q_ / b;
536 if q_.is_multiple_of(b) {
537 pb -= 1;
538 }
539
540 let tmp = self.crt_cmul(y, b);
541 let c1 = self.pmr_geq(&a, &tmp, channel)?;
542
543 let pb_crt = self.crt_constant_bundle(pb, q, channel)?;
544 let c2 = self.pmr_geq(&pb_crt, y, channel)?;
545
546 let c = self.and(&c1, &c2, channel)?;
547
548 let c_ws = one
549 .iter()
550 .map(|w| self.mul(w, &c, channel))
551 .collect::<Result<Vec<_>, _>>()?;
552 let c_crt = CrtBundle::new(c_ws);
553
554 let b_if = self.crt_cmul(&c_crt, b);
555 quotient = self.crt_add("ient, &b_if);
556
557 let tmp_if = self.crt_mul(&c_crt, &tmp, channel)?;
558 a = self.crt_sub(&a, &tmp_if);
559 }
560
561 Ok(quotient)
562 }
563}
564
565fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
570 match accuracy {
571 "100%" => match x.moduli().len() {
572 3 => vec![2; 5],
573 4 => vec![3, 26],
574 5 => vec![3, 4, 54],
575 6 => vec![5, 5, 5, 60],
576 7 => vec![5, 6, 6, 7, 86],
577 8 => vec![5, 7, 8, 8, 9, 98],
578 9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
579 10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
580 11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
581 n => panic!("unknown exact Ms for {} primes!", n),
582 },
583 "99.999%" => match x.moduli().len() {
584 8 => vec![5, 5, 6, 7, 102],
585 9 => vec![5, 5, 6, 7, 114],
586 10 => vec![5, 6, 6, 7, 102],
587 11 => vec![5, 5, 6, 7, 130],
588 n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
589 },
590 "99.99%" => match x.moduli().len() {
591 6 => vec![5, 5, 5, 42],
592 7 => vec![4, 5, 6, 88],
593 8 => vec![4, 5, 7, 78],
594 9 => vec![5, 5, 6, 84],
595 10 => vec![4, 5, 6, 112],
596 11 => vec![7, 11, 174],
597 n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
598 },
599 "99.9%" => match x.moduli().len() {
600 5 => vec![3, 5, 30],
601 6 => vec![4, 5, 48],
602 7 => vec![4, 5, 60],
603 8 => vec![3, 5, 78],
604 9 => vec![9, 140],
605 10 => vec![7, 190],
606 n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
607 },
608 "99%" => match x.moduli().len() {
609 4 => vec![3, 18],
610 5 => vec![3, 36],
611 6 => vec![3, 40],
612 7 => vec![3, 40],
613 8 => vec![126],
614 9 => vec![138],
615 10 => vec![140],
616 n => panic!("unknown 99% accurate Ms for {} primes!", n),
617 },
618 _ => panic!("get_ms: unsupported accuracy {}", accuracy),
619 }
620}