fancy_garbling/
dummy.rs

1//! Dummy implementation of `Fancy`.
2//!
3//! Useful for evaluating the circuits produced by `Fancy` without actually
4//! creating any circuits.
5
6use swanky_channel::Channel;
7use swanky_error::ErrorKind;
8
9use crate::{
10    FancyArithmetic, FancyBinary, check_binary,
11    fancy::{Fancy, FancyInput, FancyReveal, HasModulus},
12};
13
/// Simple struct that performs the fancy computation over `u16`.
///
/// `Dummy` carries no state; every operation is evaluated directly on
/// plaintext [`DummyVal`]s.
#[derive(Clone, Debug)]
pub struct Dummy {}
16
/// Wrapper around `u16`.
///
/// Pairs a plaintext value with the modulus it is interpreted under. Two
/// `DummyVal`s compare equal only when both the value and the modulus match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DummyVal {
    /// The plaintext value.
    val: u16,
    /// The modulus associated with `val`.
    modulus: u16,
}
23
24impl HasModulus for DummyVal {
25    fn modulus(&self) -> u16 {
26        self.modulus
27    }
28}
29
30impl DummyVal {
31    /// Create a new DummyVal.
32    pub fn new(val: u16, modulus: u16) -> Self {
33        Self { val, modulus }
34    }
35
36    /// Extract the value.
37    pub fn val(&self) -> u16 {
38        self.val
39    }
40}
41
42impl Dummy {
43    /// Create a new Dummy.
44    pub fn new() -> Dummy {
45        Dummy {}
46    }
47}
48
49impl Default for Dummy {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl FancyInput for Dummy {
56    type Item = DummyVal;
57
58    /// Encode a single dummy value.
59    fn encode(
60        &mut self,
61        value: u16,
62        modulus: u16,
63        _: &mut Channel,
64    ) -> swanky_error::Result<DummyVal> {
65        Ok(DummyVal::new(value, modulus))
66    }
67
68    /// Encode a slice of inputs and a slice of moduli as DummyVals.
69    fn encode_many(
70        &mut self,
71        xs: &[u16],
72        moduli: &[u16],
73        _: &mut Channel,
74    ) -> swanky_error::Result<Vec<DummyVal>> {
75        assert_eq!(xs.len(), moduli.len());
76        Ok(xs
77            .iter()
78            .zip(moduli.iter())
79            .map(|(x, q)| DummyVal::new(*x, *q))
80            .collect())
81    }
82
83    fn receive_many(
84        &mut self,
85        _moduli: &[u16],
86        _: &mut Channel,
87    ) -> swanky_error::Result<Vec<DummyVal>> {
88        // Receive is undefined for Dummy which is a single party "protocol"
89        swanky_error::bail!(
90            ErrorKind::UnsupportedError,
91            "`receive_many` is undefined for `Dummy`"
92        );
93    }
94}
95
96impl FancyBinary for Dummy {
97    fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
98        check_binary!(x);
99        check_binary!(y);
100
101        self.add(x, y)
102    }
103
104    fn and(
105        &mut self,
106        x: &Self::Item,
107        y: &Self::Item,
108        channel: &mut Channel,
109    ) -> swanky_error::Result<Self::Item> {
110        check_binary!(x);
111        check_binary!(y);
112
113        self.mul(x, y, channel)
114    }
115
116    fn negate(&mut self, x: &Self::Item) -> Self::Item {
117        check_binary!(x);
118
119        self.xor(x, &DummyVal::new(1, 2))
120    }
121}
122
123impl FancyArithmetic for Dummy {
124    fn add(&mut self, x: &DummyVal, y: &DummyVal) -> DummyVal {
125        assert_eq!(x.modulus(), y.modulus());
126        DummyVal {
127            val: (x.val + y.val) % x.modulus,
128            modulus: x.modulus,
129        }
130    }
131
132    fn sub(&mut self, x: &DummyVal, y: &DummyVal) -> DummyVal {
133        assert_eq!(x.modulus(), y.modulus());
134        DummyVal {
135            val: (x.modulus + x.val - y.val) % x.modulus,
136            modulus: x.modulus,
137        }
138    }
139
140    fn cmul(&mut self, x: &DummyVal, c: u16) -> DummyVal {
141        DummyVal {
142            val: (x.val * c) % x.modulus,
143            modulus: x.modulus,
144        }
145    }
146
147    fn mul(
148        &mut self,
149        x: &DummyVal,
150        y: &DummyVal,
151        _: &mut Channel,
152    ) -> swanky_error::Result<DummyVal> {
153        Ok(DummyVal {
154            val: x.val * y.val % x.modulus,
155            modulus: x.modulus,
156        })
157    }
158
159    fn proj(
160        &mut self,
161        x: &DummyVal,
162        modulus: u16,
163        tt: Option<Vec<u16>>,
164        _: &mut Channel,
165    ) -> swanky_error::Result<DummyVal> {
166        assert!(tt.is_some(), "`tt` must not be `None`");
167        let tt = tt.unwrap();
168        assert!(
169            tt.len() >= x.modulus() as usize,
170            "`tt` not large enough for `x`s modulus"
171        );
172        assert!(
173            tt.iter().all(|&x| x < modulus),
174            "`tt` value larger than `q`"
175        );
176        let val = tt[x.val as usize];
177        Ok(DummyVal { val, modulus })
178    }
179}
180
181impl Fancy for Dummy {
182    type Item = DummyVal;
183
184    fn constant(
185        &mut self,
186        val: u16,
187        modulus: u16,
188        _: &mut Channel,
189    ) -> swanky_error::Result<DummyVal> {
190        Ok(DummyVal { val, modulus })
191    }
192
193    fn output(&mut self, x: &DummyVal, _: &mut Channel) -> swanky_error::Result<Option<u16>> {
194        Ok(Some(x.val))
195    }
196}
197
198impl FancyReveal for Dummy {
199    fn reveal(&mut self, x: &DummyVal, _: &mut Channel) -> swanky_error::Result<u16> {
200        Ok(x.val)
201    }
202}
203
// Randomised tests of the bundle gadgets, evaluated in the clear with `Dummy`.
//
// Every test draws random inputs, runs a gadget over an empty channel and
// compares the output with the same computation done directly on `u128`s.
// Assertions report the random inputs so that a failure can be reproduced.
#[cfg(test)]
mod bundle {
    use super::*;
    use crate::{
        fancy::{ArithmeticBundleGadgets, BinaryGadgets, Bundle, BundleGadgets, CrtGadgets},
        util::{self, RngExt},
    };
    use itertools::Itertools;
    use rand::thread_rng;

    /// Number of random trials per test.
    const NITERS: usize = 1 << 10;

    #[test]
    fn test_addition() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = rng.gen_usable_composite_modulus();
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let y = d.crt_encode(y, q, channel).unwrap();
                let z = d.crt_add(&x, &y);
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x + y) % q, "x={} y={} q={}", x, y, q);
        }
    }

    #[test]
    fn test_subtraction() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = rng.gen_usable_composite_modulus();
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let y = d.crt_encode(y, q, channel).unwrap();
                let z = d.crt_sub(&x, &y);
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x + q - y) % q, "x={} y={} q={}", x, y, q);
        }
    }

    #[test]
    fn test_binary_cmul() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let c = 1 + rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_cmul(&x, c, nbits, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, (x * c) % q, "x={} c={}", x, c);
        }
    }

    #[test]
    fn test_binary_multiplication() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_multiplication_lower_half(&x, &y, channel).unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok(out)
            })
            .unwrap();
            assert_eq!(out, (x * y) % q, "x={} y={}", x, y);
        }
    }

    #[test]
    fn test_shift_extend() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let shift_size = rng.gen_usize() % nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                use crate::BinaryBundle;
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.shift_extend(&x, shift_size, channel).unwrap();
                Ok(d.bin_output(&BinaryBundle::from(z), channel)
                    .unwrap()
                    .unwrap())
            })
            .unwrap();
            assert_eq!(out, x << shift_size, "x={} shift_size={}", x, shift_size);
        }
    }

    #[test]
    fn test_binary_full_multiplication() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_mul(&x, &y, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            // Both factors are below 2^64, so the product fits in a `u128`.
            assert_eq!(out, x * y, "x={} y={}", x, y);
        }
    }

    #[test]
    fn test_binary_division() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            // Draw the divisor from [1, q - 1]: a zero divisor would make the
            // reference computation `x / y` panic.
            let y = 1 + rng.gen_u128() % (q - 1);
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_div(&x, &y, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, x / y, "x={} y={}", x, y);
        }
    }

    #[test]
    fn max() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        let n = 10;
        for _ in 0..NITERS {
            let inps = (0..n).map(|_| rng.gen_u128() % (q / 2)).collect_vec();
            let should_be = *inps.iter().max().unwrap();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let xs = inps
                    .into_iter()
                    .map(|x| d.crt_encode(x, q, channel).unwrap())
                    .collect_vec();
                let z = d.crt_max(&xs, "100%", channel).unwrap();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be);
        }
    }

    #[test]
    fn twos_complement() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let should_be = (((!x) % q) + 1) % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_twos_complement(&x, channel).unwrap();
                Ok(d.bin_output(&y, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be, "x={} y={} should_be={}", x, out, should_be);
        }
    }

    #[test]
    fn binary_addition() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let should_be = (x + y) % q;
            let mut d = Dummy::new();
            let (out, overflow) = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let (z, carry) = d.bin_addition(&x, &y, channel).unwrap();
                let overflow = d.output(&carry, channel).unwrap().unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok((out, overflow))
            })
            .unwrap();
            assert_eq!(out, should_be, "x={} y={}", x, y);
            assert_eq!(overflow > 0, x + y >= q, "x={} y={}", x, y);
        }
    }

    #[test]
    fn binary_subtraction() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let should_be = x.wrapping_sub(y) % q;
            let mut d = Dummy::new();
            let (out, overflow) = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let (z, carry) = d.bin_subtraction(&x, &y, channel).unwrap();
                let overflow = d.output(&carry, channel).unwrap().unwrap();
                let out = d.bin_output(&z, channel).unwrap().unwrap();
                Ok((out, overflow))
            })
            .unwrap();
            assert_eq!(out, should_be, "x={} y={}", x, y);
            assert_eq!(overflow > 0, (y != 0 && x >= y), "x={} y={}", x, y);
        }
    }

    #[test]
    fn binary_lt() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let should_be = x < y;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_lt(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out > 0, should_be, "x={} y={}", x, y);
        }
    }

    #[test]
    fn binary_lt_signed() {
        let mut rng = thread_rng();
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            // Reinterpret the 16-bit patterns as signed values.
            let should_be = (x as i16) < (y as i16);
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_lt_signed(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out > 0, should_be, "x={} y={}", x as i16, y as i16);
        }
    }

    #[test]
    fn binary_max() {
        let mut rng = thread_rng();
        let n = 10;
        let nbits = 16;
        let q = 1 << nbits;
        for _ in 0..NITERS {
            let inps = (0..n).map(|_| rng.gen_u128() % q).collect_vec();
            let should_be = *inps.iter().max().unwrap();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let xs = inps
                    .into_iter()
                    .map(|x| d.bin_encode(x, nbits, channel).unwrap())
                    .collect_vec();
                let z = d.bin_max(&xs, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, should_be);
        }
    }

    #[test] // bundle relu
    fn test_relu() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = crate::util::modulus_with_nprimes(4 + rng.gen_usize() % 7); // exact relu supports up to 11 primes
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.crt_encode(x, q, channel).unwrap();
                let z = d.crt_relu(&x, "100%", None, channel).unwrap();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            // Values in the upper half of the range count as negative.
            let should_be = if x >= q / 2 { 0 } else { x };
            assert_eq!(out, should_be, "x={} q={}", x, q);
        }
    }

    #[test]
    fn test_mask() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let q = crate::util::modulus_with_nprimes(4 + rng.gen_usize() % 7);
            let x = rng.gen_u128() % q;
            let b = rng.gen_bool();
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let b = d.encode(b as u16, 2, channel).unwrap();
                let x = d.crt_encode(x, q, channel).unwrap();
                let z = d.mask(&b, &x, channel).unwrap().into();
                Ok(d.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert!(
                if b { out == x } else { out == 0 },
                "b={} x={} z={}",
                b,
                x,
                out
            );
        }
    }

    #[test]
    fn binary_abs() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_abs(&x, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            // A set top bit marks a negative value: negate it within `nbits` bits.
            let should_be = if x >> (nbits - 1) > 0 {
                ((!x) + 1) & ((1 << nbits) - 1)
            } else {
                x
            };
            assert_eq!(out, should_be, "x={}", x);
        }
    }

    #[test]
    fn binary_demux() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 8;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let mut d = Dummy::new();
            let outs = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let zs = d.bin_demux(&x, channel).unwrap();
                Ok(d.outputs(&zs, channel).unwrap().unwrap())
            })
            .unwrap();
            // Exactly the wire indexed by `x` must be set.
            for (i, z) in outs.into_iter().enumerate() {
                let should_be = u16::from(i as u128 == x);
                assert_eq!(z, should_be, "x={} i={}", x, i);
            }
        }
    }

    #[test]
    fn binary_eq() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = rng.gen_usize() % 100 + 2;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            // Force equality half of the time; random pairs almost never collide.
            let y = if rng.gen_bool() {
                x
            } else {
                rng.gen_u128() % q
            };
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.bin_eq_bundles(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, u16::from(x == y), "x={} y={} nbits={}", x, y, nbits);
        }
    }

    #[test]
    fn binary_proj_eq() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = rng.gen_usize() % 100 + 2;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            // Force equality half of the time; random pairs almost never collide.
            let y = if rng.gen_bool() {
                x
            } else {
                rng.gen_u128() % q
            };
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let y = d.bin_encode(y, nbits, channel).unwrap();
                let z = d.eq_bundles(&x, &y, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            assert_eq!(out, u16::from(x == y), "x={} y={} nbits={}", x, y, nbits);
        }
    }

    #[test]
    fn binary_rsa() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let shift_size = rng.gen_usize() % nbits;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_rsa(&x, shift_size);
                Ok(d.bin_output(&z, channel).unwrap().unwrap() as i64)
            })
            .unwrap();
            // Arithmetic shift: reinterpret the 64-bit pattern as signed.
            let should_be = (x as i64) >> shift_size;
            assert_eq!(out, should_be, "x={} shift_size={}", x, shift_size);
        }
    }

    #[test]
    fn binary_rsl() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nbits = 64;
            let q = 1 << nbits;
            let x = rng.gen_u128() % q;
            let shift_size = rng.gen_usize() % nbits;
            let mut d = Dummy::new();
            let out = Channel::with(std::io::empty(), |channel| {
                let x = d.bin_encode(x, nbits, channel).unwrap();
                let z = d.bin_rsl(&x, shift_size, channel).unwrap();
                Ok(d.bin_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();
            let should_be = x >> shift_size;
            assert_eq!(out, should_be, "x={} shift_size={}", x, shift_size);
        }
    }

    #[test]
    fn test_mixed_radix_addition_msb_only() {
        let mut rng = thread_rng();
        for _ in 0..NITERS {
            let nargs = 2 + rng.gen_usize() % 10;
            let mods = (0..7).map(|_| rng.gen_modulus()).collect_vec();
            let q_total: u128 = util::product(&mods);

            // test maximum overflow
            let xs = (0..nargs)
                .map(|_| {
                    Bundle::new(
                        util::as_mixed_radix(q_total - 1, &mods)
                            .into_iter()
                            .zip(&mods)
                            .map(|(x, q)| DummyVal::new(x, *q))
                            .collect_vec(),
                    )
                })
                .collect_vec();

            let mut d = Dummy::new();

            let res = Channel::with(std::io::empty(), |channel| {
                let z = d.mixed_radix_addition_msb_only(&xs, channel).unwrap();
                Ok(d.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be =
                *util::as_mixed_radix((q_total - 1) * (nargs as u128) % q_total, &mods)
                    .last()
                    .unwrap();
            assert_eq!(
                res, should_be,
                "nargs={} mods={:?} Q={}",
                nargs, mods, q_total
            );

            // test random values
            for _ in 0..4 {
                let mut sum = 0;

                let xs = (0..nargs)
                    .map(|_| {
                        let x = rng.gen_u128() % q_total;
                        sum = (sum + x) % q_total;
                        Bundle::new(
                            util::as_mixed_radix(x, &mods)
                                .into_iter()
                                .zip(&mods)
                                .map(|(x, q)| DummyVal::new(x, *q))
                                .collect_vec(),
                        )
                    })
                    .collect_vec();

                let mut d = Dummy::new();
                let res = Channel::with(std::io::empty(), |channel| {
                    let z = d.mixed_radix_addition_msb_only(&xs, channel).unwrap();
                    Ok(d.output(&z, channel).unwrap().unwrap())
                })
                .unwrap();

                let should_be = *util::as_mixed_radix(sum, &mods).last().unwrap();
                assert_eq!(
                    res, should_be,
                    "nargs={} mods={:?} Q={} sum={}",
                    nargs, mods, q_total, sum
                );
            }
        }
    }
}
760
// Randomised tests of the PMR gadgets and CRT division, evaluated in the
// clear with `Dummy`.
#[cfg(test)]
mod pmr_tests {
    use super::*;
    use crate::{
        fancy::{BundleGadgets, CrtGadgets, FancyInput},
        util::RngExt,
    };

    #[test]
    fn pmr() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let ps = rng.gen_usable_factors();
            let q = crate::util::product(&ps);
            let pt = rng.gen_u128() % q;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let x = f.crt_encode(pt, q, channel).unwrap();
                let z = f.crt_to_pmr(&x, channel).unwrap();
                Ok(f.output_bundle(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = to_pmr_pt(pt, &ps);
            assert_eq!(res, should_be, "q={}, x={}, ps={:?}", q, pt, ps);
        }
    }

    /// Reference mixed-radix decomposition of `x` over the radices `ps`.
    ///
    /// Digit `i` is `(x / (ps[0] * ... * ps[i - 1])) % ps[i]`, i.e. the least
    /// significant digit comes first.
    fn to_pmr_pt(x: u128, ps: &[u16]) -> Vec<u16> {
        // Product of the radices consumed so far.
        let mut weight: u128 = 1;
        ps.iter()
            .map(|&p| {
                let p = u128::from(p);
                // The digit is smaller than `p`, which itself fits in a `u16`.
                let digit = ((x / weight) % p) as u16;
                weight *= p;
                digit
            })
            .collect()
    }

    #[test]
    fn pmr_lt() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            // Inputs are drawn below the product of all but the last factor.
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            let pt_y = rng.gen_u128() % q_;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.pmr_lt(&crt_x, &crt_y, channel).unwrap();
                Ok(f.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = u16::from(pt_x < pt_y);
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }

    #[test]
    fn pmr_geq() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            // Inputs are drawn below the product of all but the last factor.
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            let pt_y = rng.gen_u128() % q_;

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.pmr_geq(&crt_x, &crt_y, channel).unwrap();
                Ok(f.output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = u16::from(pt_x >= pt_y);
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }

    #[test]
    #[ignore]
    fn crt_div() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = crate::util::product(&qs);
            // Inputs are drawn below the product of all but the last factor.
            let q_ = crate::util::product(&qs[..n - 1]);
            let pt_x = rng.gen_u128() % q_;
            // Clamp away from zero: the reference `pt_x / pt_y` would panic.
            let pt_y = (rng.gen_u128() % q_).max(1);

            let res = Channel::with(std::io::empty(), |channel| {
                let mut f = Dummy::new();
                let crt_x = f.crt_encode(pt_x, q, channel).unwrap();
                let crt_y = f.crt_encode(pt_y, q, channel).unwrap();
                let z = f.crt_div(&crt_x, &crt_y, channel).unwrap();
                Ok(f.crt_output(&z, channel).unwrap().unwrap())
            })
            .unwrap();

            let should_be = pt_x / pt_y;
            assert_eq!(res, should_be, "q={}, x={}, y={}", q, pt_x, pt_y);
        }
    }
}