1use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5 FancyArithmetic, FancyBinary,
6 fancy::bundle::{Bundle, BundleGadgets},
7 util,
8};
9use itertools::Itertools;
10use std::ops::Deref;
11use swanky_channel::Channel;
12
/// A [`Bundle`] whose wires are interpreted as a CRT representation: each wire
/// holds the encoded value reduced modulo that wire's own modulus, and the
/// encoded value itself lives modulo the product of all the moduli
/// (see `composite_modulus`).
#[derive(Clone)]
pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18 pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20 CrtBundle(Bundle::new(ws))
21 }
22
23 pub fn extract(self) -> Bundle<W> {
25 self.0
26 }
27
28 pub fn composite_modulus(&self) -> u128 {
30 util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31 }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35 type Target = Bundle<W>;
36
37 fn deref(&self) -> &Bundle<W> {
38 &self.0
39 }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43 fn from(b: Bundle<W>) -> CrtBundle<W> {
44 CrtBundle(b)
45 }
46}
47
48impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50pub trait CrtGadgets:
52 FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54 fn crt_constant_bundle(
57 &mut self,
58 x: u128,
59 q: u128,
60 channel: &mut Channel,
61 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
62 let ps = util::factor(q);
63 let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
64 self.constant_bundle(&xs, &ps, channel).map(CrtBundle)
65 }
66
67 fn crt_output(
69 &mut self,
70 x: &CrtBundle<Self::Item>,
71 channel: &mut Channel,
72 ) -> swanky_error::Result<Option<u128>> {
73 let q = x.composite_modulus();
74 Ok(self
75 .output_bundle(x, channel)?
76 .map(|xs| util::crt_inv_factor(&xs, q)))
77 }
78
79 fn crt_outputs(
81 &mut self,
82 xs: &[CrtBundle<Self::Item>],
83 channel: &mut Channel,
84 ) -> swanky_error::Result<Option<Vec<u128>>> {
85 let mut zs = Vec::with_capacity(xs.len());
86 for x in xs.iter() {
87 let z = self.crt_output(x, channel)?;
88 zs.push(z);
89 }
90 Ok(zs.into_iter().collect())
91 }
92
93 fn crt_add(
98 &mut self,
99 x: &CrtBundle<Self::Item>,
100 y: &CrtBundle<Self::Item>,
101 ) -> CrtBundle<Self::Item> {
102 CrtBundle(self.add_bundles(x, y))
103 }
104
105 fn crt_sub(
107 &mut self,
108 x: &CrtBundle<Self::Item>,
109 y: &CrtBundle<Self::Item>,
110 ) -> CrtBundle<Self::Item> {
111 CrtBundle(self.sub_bundles(x, y))
112 }
113
114 fn crt_cmul(&mut self, x: &CrtBundle<Self::Item>, c: u128) -> CrtBundle<Self::Item> {
116 let cs = util::crt(c, &x.moduli());
117 CrtBundle::new(
118 x.wires()
119 .iter()
120 .zip(cs)
121 .map(|(x, c)| self.cmul(x, c))
122 .collect::<Vec<Self::Item>>(),
123 )
124 }
125
126 fn crt_mul(
128 &mut self,
129 x: &CrtBundle<Self::Item>,
130 y: &CrtBundle<Self::Item>,
131 channel: &mut Channel,
132 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
133 self.mul_bundles(x, y, channel).map(CrtBundle)
134 }
135
136 fn crt_cexp(
138 &mut self,
139 x: &CrtBundle<Self::Item>,
140 c: u16,
141 channel: &mut Channel,
142 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
143 x.wires()
144 .iter()
145 .map(|x| {
146 let p = x.modulus();
147 let tab = (0..p)
148 .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
149 .collect_vec();
150 self.proj(x, p, Some(tab), channel)
151 })
152 .collect::<swanky_error::Result<_>>()
153 .map(CrtBundle::new)
154 }
155
156 fn crt_rem(
161 &mut self,
162 x: &CrtBundle<Self::Item>,
163 p: u16,
164 channel: &mut Channel,
165 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
166 let i = x.moduli().iter().position(|&q| p == q);
167 assert!(i.is_some(), "`p` is not a modulus in the `x` bundle");
168 let i = i.unwrap();
169 let w = &x.wires()[i];
170 x.moduli()
171 .iter()
172 .map(|&q| self.mod_change(w, q, channel))
173 .collect::<swanky_error::Result<_>>()
174 .map(CrtBundle::new)
175 }
176
177 fn crt_fractional_mixed_radix(
183 &mut self,
184 bun: &CrtBundle<Self::Item>,
185 ms: &[u16],
186 channel: &mut Channel,
187 ) -> swanky_error::Result<Self::Item> {
188 let ndigits = ms.len();
189
190 let q = util::product(&bun.moduli());
191 let M = util::product(ms);
192
193 let mut ds = Vec::new();
194
195 for wire in bun.wires().iter() {
196 let p = wire.modulus();
197
198 let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
199
200 for x in 0..p {
201 let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
202 let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
203 let digits = util::as_mixed_radix(y, ms);
204 for i in 0..ndigits {
205 tabs[i].push(digits[i]);
206 }
207 }
208
209 let new_ds = tabs
210 .into_iter()
211 .enumerate()
212 .map(|(i, tt)| self.proj(wire, ms[i], Some(tt), channel))
213 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
214
215 ds.push(Bundle::new(new_ds));
216 }
217
218 self.mixed_radix_addition_msb_only(&ds, channel)
219 }
220
221 fn crt_relu(
225 &mut self,
226 x: &CrtBundle<Self::Item>,
227 accuracy: &str,
228 output_moduli: Option<&[u16]>,
229 channel: &mut Channel,
230 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
231 let factors_of_m = &get_ms(x, accuracy);
232 let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
233
234 let p = *factors_of_m.last().unwrap();
236 let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
237 let mask = self.proj(&res, 2, Some(mask_tt), channel)?;
238
239 output_moduli
241 .map(|ps| x.with_moduli(ps))
242 .as_ref()
243 .unwrap_or(x)
244 .wires()
245 .iter()
246 .map(|x| self.mul(x, &mask, channel))
247 .collect::<swanky_error::Result<_>>()
248 .map(CrtBundle::new)
249 }
250
251 fn crt_sign(
253 &mut self,
254 x: &CrtBundle<Self::Item>,
255 accuracy: &str,
256 channel: &mut Channel,
257 ) -> swanky_error::Result<Self::Item> {
258 let factors_of_m = &get_ms(x, accuracy);
259 let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
260 let p = *factors_of_m.last().unwrap();
261 let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
262 self.proj(&res, 2, Some(tt), channel)
263 }
264
265 fn crt_sgn(
269 &mut self,
270 x: &CrtBundle<Self::Item>,
271 accuracy: &str,
272 output_moduli: Option<&[u16]>,
273 channel: &mut Channel,
274 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
275 let sign = self.crt_sign(x, accuracy, channel)?;
276 output_moduli
277 .unwrap_or(&x.moduli())
278 .iter()
279 .map(|&p| {
280 let tt = vec![1, p - 1];
281 self.proj(&sign, p, Some(tt), channel)
282 })
283 .collect::<swanky_error::Result<_>>()
284 .map(CrtBundle::new)
285 }
286
287 fn crt_lt(
289 &mut self,
290 x: &CrtBundle<Self::Item>,
291 y: &CrtBundle<Self::Item>,
292 accuracy: &str,
293 channel: &mut Channel,
294 ) -> swanky_error::Result<Self::Item> {
295 let z = self.crt_sub(x, y);
296 self.crt_sign(&z, accuracy, channel)
297 }
298
299 fn crt_geq(
301 &mut self,
302 x: &CrtBundle<Self::Item>,
303 y: &CrtBundle<Self::Item>,
304 accuracy: &str,
305 channel: &mut Channel,
306 ) -> swanky_error::Result<Self::Item> {
307 let z = self.crt_lt(x, y, accuracy, channel)?;
308 Ok(self.negate(&z))
309 }
310
311 fn crt_max(
316 &mut self,
317 xs: &[CrtBundle<Self::Item>],
318 accuracy: &str,
319 channel: &mut Channel,
320 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
321 assert!(!xs.is_empty(), "`xs` cannot be empty");
322 xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
323 let pos = self.crt_lt(&x, y, accuracy, channel)?;
324 let neg = self.negate(&pos);
325 Ok(CrtBundle::new(
326 x.wires()
327 .iter()
328 .zip(y.wires().iter())
329 .map(|(x, y)| {
330 let xp = self.mul(x, &neg, channel)?;
331 let yp = self.mul(y, &pos, channel)?;
332 Ok(self.add(&xp, &yp))
333 })
334 .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
335 ))
336 })
337 }
338
339 fn crt_to_pmr(
341 &mut self,
342 xs: &CrtBundle<Self::Item>,
343 channel: &mut Channel,
344 ) -> swanky_error::Result<Bundle<Self::Item>> {
345 let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
346 let pq = p as u32 + q as u32 - 1;
347 let mut tab = Vec::with_capacity(pq as usize);
348 for z in 0..pq {
349 let mut x = 0;
350 let mut y = 0;
351 'outer: for i in 0..p as u32 {
352 for j in 0..q as u32 {
353 if (i + pq - j) % pq == z {
354 x = i;
355 y = j;
356 break 'outer;
357 }
358 }
359 }
360 debug_assert_eq!((x + pq - y) % pq, z);
361 tab.push(
362 (((x * q as u32 * util::inv(q as i128, p as i128) as u32
363 + y * p as u32 * util::inv(p as i128, q as i128) as u32)
364 / p as u32)
365 % q as u32) as u16,
366 );
367 }
368 tab
369 };
370
371 let mut gadget = |x: &Self::Item, y: &Self::Item| -> swanky_error::Result<Self::Item> {
372 let p = x.modulus();
373 let q = y.modulus();
374 let x_ = self.mod_change(x, p + q - 1, channel)?;
375 let y_ = self.mod_change(y, p + q - 1, channel)?;
376 let z = self.sub(&x_, &y_);
377 self.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
378 };
379
380 let n = xs.size();
381 let mut x = vec![vec![None; n + 1]; n + 1];
382
383 for j in 0..n {
384 x[0][j + 1] = Some(xs.wires()[j].clone());
385 }
386
387 for i in 1..=n {
388 for j in i + 1..=n {
389 let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
390 x[i][j] = Some(z);
391 }
392 }
393
394 let mut zwires = Vec::with_capacity(n);
395 for i in 0..n {
396 zwires.push(x[i][i + 1].take().unwrap());
397 }
398 Ok(Bundle::new(zwires))
399 }
400
401 fn pmr_lt(
407 &mut self,
408 x: &CrtBundle<Self::Item>,
409 y: &CrtBundle<Self::Item>,
410 channel: &mut Channel,
411 ) -> swanky_error::Result<Self::Item> {
412 let z = self.crt_sub(x, y);
413 let mut pmr = self.crt_to_pmr(&z, channel)?;
414 let w = pmr.pop().unwrap();
415 let mut tab = vec![1; w.modulus() as usize];
416 tab[0] = 0;
417 self.proj(&w, 2, Some(tab), channel)
418 }
419
420 fn pmr_geq(
426 &mut self,
427 x: &CrtBundle<Self::Item>,
428 y: &CrtBundle<Self::Item>,
429 channel: &mut Channel,
430 ) -> swanky_error::Result<Self::Item> {
431 let z = self.pmr_lt(x, y, channel)?;
432 Ok(self.negate(&z))
433 }
434
435 fn crt_div(
441 &mut self,
442 x: &CrtBundle<Self::Item>,
443 y: &CrtBundle<Self::Item>,
444 channel: &mut Channel,
445 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
446 assert_eq!(x.moduli(), y.moduli());
447
448 let q = x.composite_modulus();
449
450 let nprimes = x.moduli().len();
452 let qs_ = &x.moduli()[..nprimes - 1];
453 let q_ = util::product(qs_);
454 let l = 128 - q_.leading_zeros();
455
456 let mut quotient = self.crt_constant_bundle(0, q, channel)?;
457 let mut a = x.clone();
458
459 let one = self.crt_constant_bundle(1, q, channel)?;
460 for i in 0..l {
461 let b = 2u128.pow(l - i - 1);
462 let mut pb = q_ / b;
463 if q_.is_multiple_of(b) {
464 pb -= 1;
465 }
466
467 let tmp = self.crt_cmul(y, b);
468 let c1 = self.pmr_geq(&a, &tmp, channel)?;
469
470 let pb_crt = self.crt_constant_bundle(pb, q, channel)?;
471 let c2 = self.pmr_geq(&pb_crt, y, channel)?;
472
473 let c = self.and(&c1, &c2, channel)?;
474
475 let c_ws = one
476 .iter()
477 .map(|w| self.mul(w, &c, channel))
478 .collect::<Result<Vec<_>, _>>()?;
479 let c_crt = CrtBundle::new(c_ws);
480
481 let b_if = self.crt_cmul(&c_crt, b);
482 quotient = self.crt_add("ient, &b_if);
483
484 let tmp_if = self.crt_mul(&c_crt, &tmp, channel)?;
485 a = self.crt_sub(&a, &tmp_if);
486 }
487
488 Ok(quotient)
489 }
490}
491
492fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
497 match accuracy {
498 "100%" => match x.moduli().len() {
499 3 => vec![2; 5],
500 4 => vec![3, 26],
501 5 => vec![3, 4, 54],
502 6 => vec![5, 5, 5, 60],
503 7 => vec![5, 6, 6, 7, 86],
504 8 => vec![5, 7, 8, 8, 9, 98],
505 9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
506 10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
507 11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
508 n => panic!("unknown exact Ms for {} primes!", n),
509 },
510 "99.999%" => match x.moduli().len() {
511 8 => vec![5, 5, 6, 7, 102],
512 9 => vec![5, 5, 6, 7, 114],
513 10 => vec![5, 6, 6, 7, 102],
514 11 => vec![5, 5, 6, 7, 130],
515 n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
516 },
517 "99.99%" => match x.moduli().len() {
518 6 => vec![5, 5, 5, 42],
519 7 => vec![4, 5, 6, 88],
520 8 => vec![4, 5, 7, 78],
521 9 => vec![5, 5, 6, 84],
522 10 => vec![4, 5, 6, 112],
523 11 => vec![7, 11, 174],
524 n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
525 },
526 "99.9%" => match x.moduli().len() {
527 5 => vec![3, 5, 30],
528 6 => vec![4, 5, 48],
529 7 => vec![4, 5, 60],
530 8 => vec![3, 5, 78],
531 9 => vec![9, 140],
532 10 => vec![7, 190],
533 n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
534 },
535 "99%" => match x.moduli().len() {
536 4 => vec![3, 18],
537 5 => vec![3, 36],
538 6 => vec![3, 40],
539 7 => vec![3, 40],
540 8 => vec![126],
541 9 => vec![138],
542 10 => vec![140],
543 n => panic!("unknown 99% accurate Ms for {} primes!", n),
544 },
545 _ => panic!("get_ms: unsupported accuracy {}", accuracy),
546 }
547}