1use swanky_channel::Channel;
7use swanky_error::ErrorKind;
8
9use crate::{
10 FancyArithmetic, FancyBinary, check_binary,
11 fancy::{Fancy, FancyInput, FancyReveal, HasModulus},
12};
13
/// A plaintext evaluator for the `Fancy` family of traits.
///
/// Every operation is computed directly on `u16` values; the `Channel`
/// arguments passed to its trait methods are ignored.
#[derive(Debug)]
pub struct Dummy {}
16
/// A plaintext value tagged with the modulus it is defined over.
#[derive(Debug, Clone)]
pub struct DummyVal {
    /// The plaintext value. Results of `Dummy`'s arithmetic are reduced modulo `modulus`.
    val: u16,
    /// The modulus associated with `val`.
    modulus: u16,
}
23
24impl HasModulus for DummyVal {
25 fn modulus(&self) -> u16 {
26 self.modulus
27 }
28}
29
30impl DummyVal {
31 pub fn new(val: u16, modulus: u16) -> Self {
33 Self { val, modulus }
34 }
35
36 pub fn val(&self) -> u16 {
38 self.val
39 }
40}
41
42impl Dummy {
43 pub fn new() -> Dummy {
45 Dummy {}
46 }
47}
48
49impl Default for Dummy {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl FancyInput for Dummy {
56 type Item = DummyVal;
57
58 fn encode(
60 &mut self,
61 value: u16,
62 modulus: u16,
63 _: &mut Channel,
64 ) -> swanky_error::Result<DummyVal> {
65 Ok(DummyVal::new(value, modulus))
66 }
67
68 fn encode_many(
70 &mut self,
71 xs: &[u16],
72 moduli: &[u16],
73 _: &mut Channel,
74 ) -> swanky_error::Result<Vec<DummyVal>> {
75 assert_eq!(xs.len(), moduli.len());
76 Ok(xs
77 .iter()
78 .zip(moduli.iter())
79 .map(|(x, q)| DummyVal::new(*x, *q))
80 .collect())
81 }
82
83 fn receive_many(
84 &mut self,
85 _moduli: &[u16],
86 _: &mut Channel,
87 ) -> swanky_error::Result<Vec<DummyVal>> {
88 swanky_error::bail!(
90 ErrorKind::UnsupportedError,
91 "`receive_many` is undefined for `Dummy`"
92 );
93 }
94}
95
96impl FancyBinary for Dummy {
97 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
98 check_binary!(x);
99 check_binary!(y);
100
101 self.add(x, y)
102 }
103
104 fn and(
105 &mut self,
106 x: &Self::Item,
107 y: &Self::Item,
108 channel: &mut Channel,
109 ) -> swanky_error::Result<Self::Item> {
110 check_binary!(x);
111 check_binary!(y);
112
113 self.mul(x, y, channel)
114 }
115
116 fn negate(&mut self, x: &Self::Item) -> Self::Item {
117 check_binary!(x);
118
119 self.xor(x, &DummyVal::new(1, 2))
120 }
121}
122
123impl FancyArithmetic for Dummy {
124 fn add(&mut self, x: &DummyVal, y: &DummyVal) -> DummyVal {
125 assert_eq!(x.modulus(), y.modulus());
126 DummyVal {
127 val: (x.val + y.val) % x.modulus,
128 modulus: x.modulus,
129 }
130 }
131
132 fn sub(&mut self, x: &DummyVal, y: &DummyVal) -> DummyVal {
133 assert_eq!(x.modulus(), y.modulus());
134 DummyVal {
135 val: (x.modulus + x.val - y.val) % x.modulus,
136 modulus: x.modulus,
137 }
138 }
139
140 fn cmul(&mut self, x: &DummyVal, c: u16) -> DummyVal {
141 DummyVal {
142 val: (x.val * c) % x.modulus,
143 modulus: x.modulus,
144 }
145 }
146
147 fn mul(
148 &mut self,
149 x: &DummyVal,
150 y: &DummyVal,
151 _: &mut Channel,
152 ) -> swanky_error::Result<DummyVal> {
153 Ok(DummyVal {
154 val: x.val * y.val % x.modulus,
155 modulus: x.modulus,
156 })
157 }
158
159 fn proj(
160 &mut self,
161 x: &DummyVal,
162 modulus: u16,
163 tt: Option<Vec<u16>>,
164 _: &mut Channel,
165 ) -> swanky_error::Result<DummyVal> {
166 assert!(tt.is_some(), "`tt` must not be `None`");
167 let tt = tt.unwrap();
168 assert!(
169 tt.len() >= x.modulus() as usize,
170 "`tt` not large enough for `x`s modulus"
171 );
172 assert!(
173 tt.iter().all(|&x| x < modulus),
174 "`tt` value larger than `q`"
175 );
176 let val = tt[x.val as usize];
177 Ok(DummyVal { val, modulus })
178 }
179}
180
181impl Fancy for Dummy {
182 type Item = DummyVal;
183
184 fn constant(
185 &mut self,
186 val: u16,
187 modulus: u16,
188 _: &mut Channel,
189 ) -> swanky_error::Result<DummyVal> {
190 Ok(DummyVal { val, modulus })
191 }
192
193 fn output(&mut self, x: &DummyVal, _: &mut Channel) -> swanky_error::Result<Option<u16>> {
194 Ok(Some(x.val))
195 }
196}
197
198impl FancyReveal for Dummy {
199 fn reveal(&mut self, x: &DummyVal, _: &mut Channel) -> swanky_error::Result<u16> {
200 Ok(x.val)
201 }
202}
203
#[cfg(test)]
mod bundle {
    use super::*;
    use crate::{
        fancy::{ArithmeticBundleGadgets, BinaryGadgets, Bundle, BundleGadgets, CrtGadgets},
        util::{self, RngExt},
    };
    use itertools::Itertools;
    use rand::thread_rng;

    /// Number of random trials run by each test.
    const NITERS: usize = 1 << 10;

    #[test]
    fn test_addition() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = rng.gen_usable_composite_modulus();
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let y = d.crt_encode(y, q, channel).unwrap();
                let z = d.crt_add(&x, &y);
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x + y) % q);
        }
    }

    #[test]
    fn test_subtraction() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = rng.gen_usable_composite_modulus();
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let y = d.crt_encode(y, q, channel).unwrap();
                let z = d.crt_sub(&x, &y);
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x + q - y) % q);
        }
    }

    #[test]
    fn test_binary_cmul() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let c = 1 + rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_cmul(&x, c, nbits, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x * c) % q);
        }
    }

    #[test]
    fn test_binary_multiplication() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_multiplication_lower_half(&x, &y, channel).unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok(out)
            })
            .unwrap();
            assert_eq!(out, (x * y) % q);
        }
    }

    #[test]
    fn test_shift_extend() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let shift_size = rng.gen_usize() % nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                use crate::BinaryBundle;
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.shift_extend(&x, shift_size, channel).unwrap();
                Ok(d.bin_output(&BinaryBundle::from(z), channel)
                    .unwrap()
                    .unwrap())
            })
            .unwrap();
            assert_eq!(out, x << shift_size);
        }
    }

    #[test]
    fn test_binary_full_multiplication() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_mul(&x, &y, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, x * y);
        }
    }

    #[test]
    fn test_binary_division() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            // Draw a nonzero divisor so the plaintext reference `x / y` cannot panic.
            let y = 1 + rng.gen_u128() % (q - 1);
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_div(&x, &y, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, x / y, "x={} y={}", x, y);
        }
    }

    #[test]
    fn max() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        let n = 10;
        for _ in 0..NITERS {
            let inps = (0..n).map(|_| rng.gen_u128() % (q / 2)).collect_vec();
            let should_be = *inps.iter().max().unwrap();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let xs = inps
                    .into_iter()
                    .map(|x| d.crt_encode(x, q, channel).unwrap())
                    .collect_vec();
                let z = d.crt_max(&xs, "100%", channel).unwrap();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn twos_complement() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let should_be = (((!x) % q) + 1) % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_twos_complement(&x, channel).unwrap();
                Ok(d.bin_output(&y, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be, "x={} y={} should_be={}", x, out, should_be);
        }
    }

    #[test]
    fn binary_addition() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let should_be = (x + y) % q;
            let mut d = Dummy::new();
            let (out, overflow) = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let (z, overflow_bit) = d.bin_addition(&x, &y, channel).unwrap();
                let overflow = d.output(&overflow_bit, channel).unwrap().unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok((out, overflow))
            })
            .unwrap();
            assert_eq!(out, should_be);
            assert_eq!(overflow > 0, x + y >= q);
        }
    }

    #[test]
    fn binary_subtraction() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let (should_be, _) = x.overflowing_sub(y);
            let should_be = should_be % q;
            let mut d = Dummy::new();
            let (out, overflow) = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let (z, overflow_bit) = d.bin_subtraction(&x, &y, channel).unwrap();
                let overflow = d.output(&overflow_bit, channel).unwrap().unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok((out, overflow))
            })
            .unwrap();
            assert_eq!(out, should_be);
            assert_eq!(overflow > 0, (y != 0 && x >= y), "x={} y={}", x, y);
        }
    }

    #[test]
    fn binary_lt() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let should_be = x < y;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_lt(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out > 0, should_be, "x={} y={}", x, y);
        }
    }

    #[test]
    fn binary_lt_signed() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            // Reinterpret the low 16 bits as two's-complement signed values.
            let should_be = (x as i16) < (y as i16);
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_lt_signed(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out > 0, should_be, "x={} y={}", x as i16, y as i16);
        }
    }

    #[test]
    fn binary_max() {
        let mut rng = thread_rng();
        let n = 10;
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let inps = (0..n).map(|_| rng.gen_u128() % q).collect_vec();
            let should_be = *inps.iter().max().unwrap();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let xs = inps
                    .into_iter()
                    .map(|x| d.bin_encode(x, nbits, channel).unwrap())
                    .collect_vec();
                let z = d.bin_max(&xs, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn test_relu() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = crate::util::modulus_with_nprimes(4 + rng.gen_usize() % 7);
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let z = d.crt_relu(&x, "100%", None, channel).unwrap();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            // Values in the upper half of [0, q) represent negatives.
            if x >= q / 2 {
                assert_eq!(out, 0);
            } else {
                assert_eq!(out, x);
            }
        }
    }

    #[test]
    fn test_mask() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = crate::util::modulus_with_nprimes(4 + rng.gen_usize() % 7);
            let x = rng.gen_u128() % q;
            let b = rng.gen_bool();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let b = d.encode(b as u16, 2, channel).unwrap();
                let x = d.crt_encode(x, q, channel).unwrap();
                let z = d.mask(&b, &x, channel).unwrap().into();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert!(
                if b { out == x } else { out == 0 },
                "b={} x={} z={}",
                b,
                x,
                out
            );
        }
    }

    #[test]
    fn binary_abs() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_abs(&x, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            // If the sign bit is set, negate in two's complement.
            let should_be = if x >> (nbits - 1) > 0 {
                ((!x) + 1) & ((1 << nbits) - 1)
            } else {
                x
            };
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn binary_demux() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 8;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let outs = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let zs = d.bin_demux(&x, channel).unwrap();
                Ok(d.outputs(&zs, channel).unwrap().unwrap())
            })
            .unwrap();
            for (i, z) in outs.into_iter().enumerate() {
                if i as u128 == x {
                    assert_eq!(z, 1);
                } else {
                    assert_eq!(z, 0);
                }
            }
        }
    }

    #[test]
    fn binary_eq() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = rng.gen_usize() % 100 + 2;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = if rng.gen_bool() {
                x
            } else {
                rng.gen_u128() % q
            };
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_eq_bundles(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x == y) as u16);
        }
    }

    #[test]
    fn binary_proj_eq() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = rng.gen_usize() % 100 + 2;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = if rng.gen_bool() {
                x
            } else {
                rng.gen_u128() % q
            };
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.eq_bundles(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x == y) as u16);
        }
    }

    #[test]
    fn binary_rsa() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let shift_size = rng.gen_usize() % nbits;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_rsa(&x, shift_size);
                Ok(d.bin_output(&z, channel).unwrap().unwrap() as i64)
            })
            .unwrap();
            let should_be = (x as i64) >> shift_size;
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn binary_rsl() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let shift_size = rng.gen_usize() % nbits;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_rsl(&x, shift_size, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            let should_be = x >> shift_size;
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn test_mixed_radix_addition_msb_only() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nargs = 2 + rng.gen_usize() % 10;
            let mods = (0..7).map(|_| rng.gen_modulus()).collect_vec();
            // Product of all mixed-radix moduli (the total range of a value).
            let q_prod: u128 = util::product(&mods);

            // Worst case: every argument is the maximal value `q_prod - 1`.
            let xs = (0..nargs)
                .map(|_| {
                    Bundle::new(
                        util::as_mixed_radix(q_prod - 1, &mods)
                            .into_iter()
                            .zip(&mods)
                            .map(|(x, q)| DummyVal::new(x, *q))
                            .collect_vec(),
                    )
                })
                .collect_vec();

            let mut d = Dummy::new();

            let res = Channel::with(std::io::empty(), |channel| {
                let z = d.mixed_radix_addition_msb_only(&xs, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be =
                *util::as_mixed_radix((q_prod - 1) * (nargs as u128) % q_prod, &mods)
                    .last()
                    .unwrap();
            assert_eq!(
                res, should_be,
                "nargs={} mods={:?} Q={}",
                nargs, mods, q_prod
            );

            // Random arguments, checked against the plaintext sum.
            for _ in 0..4 {
                let mut sum = 0;

                let xs = (0..nargs)
                    .map(|_| {
                        let x = rng.gen_u128() % q_prod;
                        sum = (sum + x) % q_prod;
                        Bundle::new(
                            util::as_mixed_radix(x, &mods)
                                .into_iter()
                                .zip(&mods)
                                .map(|(x, q)| DummyVal::new(x, *q))
                                .collect_vec(),
                        )
                    })
                    .collect_vec();

                let mut d = Dummy::new();
                let res = Channel::with(std::io::empty(), |channel| {
                    let z = d.mixed_radix_addition_msb_only(&xs, channel).unwrap();
                    Ok(d.output(&z, channel).unwrap().unwrap())
                })
                .unwrap();

                let should_be = *util::as_mixed_radix(sum, &mods).last().unwrap();
                assert_eq!(
                    res, should_be,
                    "nargs={} mods={:?} Q={} sum={}",
                    nargs, mods, q_prod, sum
                );
            }
        }
    }
}
760
#[cfg(test)]
mod pmr_tests {
    use super::*;
    use crate::{
        fancy::{BundleGadgets, CrtGadgets, FancyInput},
        util::RngExt,
    };

    #[test]
    fn pmr() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let ps = rng.gen_usable_factors();
            let q = crate::util::product(&ps);
            let pt = rng.gen_u128() % q;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let x = f.crt_encode(pt, q, channel).unwrap();
                let z = f.crt_to_pmr(&x, channel).unwrap();
                Ok(f.output_bundle(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = to_pmr_pt(pt, &ps);
            assert_eq!(res, should_be);
        }
    }

    /// Plaintext reference for `crt_to_pmr`: digit `i` of the result is
    /// `(x / (ps[0] * ... * ps[i-1])) % ps[i]`, least-significant digit first.
    fn to_pmr_pt(x: u128, ps: &[u16]) -> Vec<u16> {
        let mut place = 1u128;
        ps.iter()
            .map(|&p| {
                let p = u128::from(p);
                let digit = ((x / place) % p) as u16;
                place *= p;
                digit
            })
            .collect()
    }

    #[test]
    fn pmr_lt() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            // Inputs are drawn below the product of all but the last factor.
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            let pt_y = rng.gen_u128() % q_;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.pmr_lt(&crt_x, &crt_y, channel).unwrap();
                Ok(f.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = u16::from(pt_x < pt_y);
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }

    #[test]
    fn pmr_geq() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            // Inputs are drawn below the product of all but the last factor.
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            let pt_y = rng.gen_u128() % q_;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.pmr_geq(&crt_x, &crt_y, channel).unwrap();
                Ok(f.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = u16::from(pt_x >= pt_y);
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }

    #[test]
    #[ignore]
    fn crt_div() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            let pt_y = rng.gen_u128() % q_;
            // The plaintext reference `pt_x / pt_y` would panic on a zero divisor.
            if pt_y == 0 {
                continue;
            }

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.crt_div(&crt_x, &crt_y, channel).unwrap();
                Ok(f.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = pt_x / pt_y;
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }
}