fancy_garbling/
circuit.rs

1//! DSL for creating circuits compatible with fancy-garbling in the old-fashioned way,
2//! where you create a circuit for a computation then garble it.
3
4use crate::{
5    dummy::{Dummy, DummyVal},
6    fancy::{BinaryBundle, CrtBundle, Fancy, FancyInput, HasModulus},
7    informer::Informer,
8};
9use itertools::Itertools;
10use std::{collections::HashMap, fmt::Display};
11use swanky_channel::Channel;
12
13mod binary;
14pub use binary::{BinaryCircuit, BinaryGate};
15mod arithmetic;
16pub use arithmetic::{ArithmeticCircuit, ArithmeticGate};
17
/// The index and modulus of a gate in a circuit.
///
/// A `CircuitRef` is a small, copyable handle: it can be compared for
/// equality and used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CircuitRef {
    /// Position of the referenced gate within the circuit.
    pub(crate) ix: usize,
    /// Modulus of the value carried by the referenced gate.
    pub(crate) modulus: u16,
}
25
26impl std::fmt::Display for CircuitRef {
27    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
28        write!(f, "[{} | {}]", self.ix, self.modulus)
29    }
30}
31
32impl HasModulus for CircuitRef {
33    fn modulus(&self) -> u16 {
34        self.modulus
35    }
36}
37
/// Trait to display circuit evaluation costs.
///
/// A blanket implementation is available for all circuits
/// that can be evaluated with an `Informer`.
pub trait CircuitInfo {
    /// Print circuit info.
    ///
    /// # Errors
    /// Returns an error if gathering the information fails.
    fn print_info(&self) -> swanky_error::Result<()>;
}
46
47impl<C: EvaluableCircuit<Informer<Dummy>>> CircuitInfo for C {
48    fn print_info(&self) -> swanky_error::Result<()> {
49        let mut informer = crate::informer::Informer::new(Dummy::new());
50
51        // encode inputs as InformerVals
52        let gb = Channel::with(std::io::empty(), |channel| {
53            self.get_garbler_input_refs()
54                .iter()
55                .map(|r| informer.encode(0, r.modulus(), channel))
56                .collect::<swanky_error::Result<Vec<DummyVal>>>()
57        })?;
58        let ev = Channel::with(std::io::empty(), |channel| {
59            self.get_evaluator_input_refs()
60                .iter()
61                .map(|r| informer.encode(0, r.modulus(), channel))
62                .collect::<swanky_error::Result<Vec<DummyVal>>>()
63        })?;
64
65        Channel::with(std::io::empty(), |c| self.eval(&mut informer, &gb, &ev, c))?;
66        println!("{}", informer.stats());
67        Ok(())
68    }
69}
70
71/// A Circuit that can be evaluated by a given Fancy object
72///
73/// Supertrait ensures that circuit can be built by `CircuitBuilder`
74pub trait EvaluableCircuit<F: Fancy>: CircuitType {
75    /// Evaluate the circuit.
76    ///
77    /// The argument `f` provides the [`Fancy`] instantiation to use during
78    /// evaluation, and the actual circuit data is accessed using `channel`. The
79    /// output is a vector of `Option<u16>`s because certain [`Fancy`]
80    /// instantiations may not output anything (for example, a garbler doesn't
81    /// have an "output", so `None` would be returned here).
82    fn eval(
83        &self,
84        f: &mut F,
85        garbler_inputs: &[F::Item],
86        evaluator_inputs: &[F::Item],
87        channel: &mut Channel,
88    ) -> swanky_error::Result<Option<Vec<u16>>> {
89        let wirelabels = self.eval_to_wirelabels(f, garbler_inputs, evaluator_inputs, channel)?;
90        f.outputs(&wirelabels, channel)
91    }
92
93    /// Evaluate the circuit up to producing the output [`Fancy::Item`]s, and
94    /// output those.
95    ///
96    /// The argument `f` provides the [`Fancy`] instantiation to use during
97    /// evaluation, and the actual circuit data is accessed using `channel`. The
98    /// output is a vector of [`Fancy::Item`]s that correspond to the outputs of
99    /// the circuit evaluation.
100    fn eval_to_wirelabels(
101        &self,
102        f: &mut F,
103        garbler_inputs: &[F::Item],
104        evaluator_inputs: &[F::Item],
105        channel: &mut Channel,
106    ) -> swanky_error::Result<Vec<F::Item>>;
107}
108
/// Trait representing circuit gates that can be used in `CircuitType`.
///
/// `CircuitBuilder` uses these constructors to create the constant and input
/// gates of a circuit.
pub trait GateType: Display {
    /// Generate a constant gate holding `val`.
    fn make_constant(val: u16) -> Self;

    /// Generate a garbler input gate with identifier `id`.
    fn make_garbler_input(id: usize) -> Self;

    /// Generate an evaluator input gate with identifier `id`.
    fn make_evaluator_input(id: usize) -> Self;
}
120
121impl GateType for ArithmeticGate {
122    fn make_constant(val: u16) -> Self {
123        Self::Constant { val }
124    }
125
126    fn make_garbler_input(id: usize) -> Self {
127        Self::GarblerInput { id }
128    }
129
130    fn make_evaluator_input(id: usize) -> Self {
131        Self::EvaluatorInput { id }
132    }
133}
134
/// Trait representing circuits that can be built by `CircuitBuilder`.
pub trait CircuitType {
    /// Gates that the circuit is composed of
    type Gate: GateType;

    /// Increase number of nonfree gates
    fn increment_nonfree_gates(&mut self);

    /// Make a new `Circuit` object.
    fn new(ngates: Option<usize>) -> Self;

    /// Get all output refs
    fn get_output_refs(&self) -> &[CircuitRef];

    /// Get all garbler input refs
    fn get_garbler_input_refs(&self) -> &[CircuitRef];

    /// Get all evaluator input refs
    fn get_evaluator_input_refs(&self) -> &[CircuitRef];

    /// Get number of nonfree gates
    fn get_num_nonfree_gates(&self) -> usize;

    /// Add a gate
    fn push_gates(&mut self, gate: Self::Gate);

    /// Add a constant ref
    fn push_const_ref(&mut self, xref: CircuitRef);

    /// Add an output ref
    fn push_output_ref(&mut self, xref: CircuitRef);

    /// Add a garbler input ref
    fn push_garbler_input_ref(&mut self, xref: CircuitRef);

    /// Add an evaluator input ref
    fn push_evaluator_input_ref(&mut self, xref: CircuitRef);

    /// Add wire modulus
    fn push_modulus(&mut self, modulus: u16);

    /// Return the modulus of the garbler input indexed by `i`.
    fn garbler_input_mod(&self, i: usize) -> u16;

    /// Return the modulus of the evaluator input indexed by `i`.
    fn evaluator_input_mod(&self, i: usize) -> u16;

    /// Return the number of garbler inputs.
    #[inline]
    fn num_garbler_inputs(&self) -> usize {
        self.get_garbler_input_refs().len()
    }

    /// Return the number of evaluator inputs.
    #[inline]
    fn num_evaluator_inputs(&self) -> usize {
        self.get_evaluator_input_refs().len()
    }

    /// Return the number of outputs.
    #[inline]
    fn noutputs(&self) -> usize {
        self.get_output_refs().len()
    }
}
200
201/// Evaluate the circuit in plaintext.
202///
203/// # Panics
204/// Panics if either `garbler_inputs.len()` or `evaluator_inputs.len()` does not
205/// equal the circuit's expected number of inputs.
206pub fn eval_plain<C: EvaluableCircuit<Dummy>>(
207    circuit: &C,
208    garbler_inputs: &[u16],
209    evaluator_inputs: &[u16],
210) -> swanky_error::Result<Vec<u16>> {
211    assert_eq!(garbler_inputs.len(), circuit.num_garbler_inputs());
212    assert_eq!(evaluator_inputs.len(), circuit.num_evaluator_inputs());
213
214    let mut dummy = crate::dummy::Dummy::new();
215
216    // encode inputs as DummyVals
217    let gb = garbler_inputs
218        .iter()
219        .zip(circuit.get_garbler_input_refs().iter())
220        .map(|(x, r)| DummyVal::new(*x, r.modulus()))
221        .collect_vec();
222    let ev = evaluator_inputs
223        .iter()
224        .zip(circuit.get_evaluator_input_refs().iter())
225        .map(|(x, r)| DummyVal::new(*x, r.modulus()))
226        .collect_vec();
227
228    // XXX `unwrap` is used!
229    let outputs = Channel::with(std::io::empty(), |c| {
230        // XXX `unwrap` is used!
231        Ok(circuit.eval(&mut dummy, &gb, &ev, c).unwrap())
232    })
233    .unwrap();
234    Ok(outputs.expect("dummy will always return Some(u16) output"))
235}
236
/// CircuitBuilder is used to build circuits.
pub struct CircuitBuilder<Circuit> {
    /// Index that will be assigned to the next gate added to the circuit.
    next_ref_ix: usize,
    /// Identifier that will be assigned to the next garbler input.
    next_garbler_input_id: usize,
    /// Identifier that will be assigned to the next evaluator input.
    next_evaluator_input_id: usize,
    /// Constants created so far, keyed by `(value, modulus)`, so that a
    /// repeated request returns the existing ref.
    const_map: HashMap<(u16, u16), CircuitRef>,
    /// The circuit under construction.
    circ: Circuit,
}
245
246impl<Circuit: CircuitType> Fancy for CircuitBuilder<Circuit> {
247    type Item = CircuitRef;
248
249    fn constant(
250        &mut self,
251        val: u16,
252        modulus: u16,
253        _: &mut Channel,
254    ) -> swanky_error::Result<CircuitRef> {
255        Ok(self.lookup_constant(val, modulus))
256    }
257
258    fn output(&mut self, xref: &CircuitRef, _: &mut Channel) -> swanky_error::Result<Option<u16>> {
259        self.circ.push_output_ref(*xref);
260        Ok(None)
261    }
262}
263
264impl<Circuit: CircuitType> CircuitBuilder<Circuit> {
265    /// Make a new `CircuitBuilder`.
266    pub fn new() -> Self {
267        CircuitBuilder {
268            next_ref_ix: 0,
269            next_garbler_input_id: 0,
270            next_evaluator_input_id: 0,
271            const_map: HashMap::new(),
272            circ: Circuit::new(None),
273        }
274    }
275
276    /// Finish circuit building, outputting the resulting circuit.
277    pub fn finish(self) -> Circuit {
278        self.circ
279    }
280
281    /// Look up a constant in the internal constant map, or add it if no such
282    /// constant exists.
283    fn lookup_constant(&mut self, val: u16, modulus: u16) -> CircuitRef {
284        match self.const_map.get(&(val, modulus)) {
285            Some(&r) => r,
286            None => {
287                let gate = Circuit::Gate::make_constant(val);
288                let r = self.gate(gate, modulus);
289                self.const_map.insert((val, modulus), r);
290                self.circ.push_const_ref(r);
291                r
292            }
293        }
294    }
295
296    fn get_next_garbler_input_id(&mut self) -> usize {
297        let current = self.next_garbler_input_id;
298        self.next_garbler_input_id += 1;
299        current
300    }
301
302    fn get_next_evaluator_input_id(&mut self) -> usize {
303        let current = self.next_evaluator_input_id;
304        self.next_evaluator_input_id += 1;
305        current
306    }
307
308    fn get_next_ciphertext_id(&mut self) -> usize {
309        let current = self.circ.get_num_nonfree_gates();
310        self.circ.increment_nonfree_gates();
311        current
312    }
313
314    fn get_next_ref_ix(&mut self) -> usize {
315        let current = self.next_ref_ix;
316        self.next_ref_ix += 1;
317        current
318    }
319
320    fn gate(&mut self, gate: Circuit::Gate, modulus: u16) -> CircuitRef {
321        self.circ.push_gates(gate);
322        self.circ.push_modulus(modulus);
323        let ix = self.get_next_ref_ix();
324        CircuitRef { ix, modulus }
325    }
326
327    /// Get CircuitRef for a garbler input wire.
328    pub fn garbler_input(&mut self, modulus: u16) -> CircuitRef {
329        let id = self.get_next_garbler_input_id();
330        let r = self.gate(Circuit::Gate::make_garbler_input(id), modulus);
331        self.circ.push_garbler_input_ref(r);
332        r
333    }
334
335    /// Get CircuitRef for an evaluator input wire.
336    pub fn evaluator_input(&mut self, modulus: u16) -> CircuitRef {
337        let id = self.get_next_evaluator_input_id();
338        let r = self.gate(Circuit::Gate::make_evaluator_input(id), modulus);
339        self.circ.push_evaluator_input_ref(r);
340        r
341    }
342
343    /// Get a vec of CircuitRefs for garbler inputs.
344    pub fn garbler_inputs(&mut self, mods: &[u16]) -> Vec<CircuitRef> {
345        mods.iter().map(|q| self.garbler_input(*q)).collect()
346    }
347
348    /// Get a vec of CircuitRefs for garbler inputs.
349    pub fn evaluator_inputs(&mut self, mods: &[u16]) -> Vec<CircuitRef> {
350        mods.iter().map(|q| self.evaluator_input(*q)).collect()
351    }
352
353    /// Get a CrtBundle for the garbler using composite modulus Q
354    pub fn crt_garbler_input(&mut self, modulus: u128) -> CrtBundle<CircuitRef> {
355        CrtBundle::new(self.garbler_inputs(&crate::util::factor(modulus)))
356    }
357
358    /// Get a CrtBundle for the evaluator using composite modulus Q
359    pub fn crt_evaluator_input(&mut self, modulus: u128) -> CrtBundle<CircuitRef> {
360        CrtBundle::new(self.evaluator_inputs(&crate::util::factor(modulus)))
361    }
362
363    /// Get a BinaryBundle for the garbler with n bits.
364    pub fn bin_garbler_input(&mut self, nbits: usize) -> BinaryBundle<CircuitRef> {
365        BinaryBundle::new(self.garbler_inputs(&vec![2; nbits]))
366    }
367
368    /// Get a BinaryBundle for the evaluator with n bits.
369    pub fn bin_evaluator_input(&mut self, nbits: usize) -> BinaryBundle<CircuitRef> {
370        BinaryBundle::new(self.evaluator_inputs(&vec![2; nbits]))
371    }
372}
373
374impl<Circuit: CircuitType> Default for CircuitBuilder<Circuit> {
375    fn default() -> Self {
376        Self::new()
377    }
378}
379
#[cfg(test)]
mod plaintext {
    use super::*;
    use crate::{FancyArithmetic, FancyBinary, util::RngExt};
    use itertools::Itertools;
    use rand::thread_rng;

    /// AND over a random number of evaluator bits matches the bitwise AND of the inputs.
    #[test] // {{{ and_gate_fan_n
    fn and_gate_fan_n() {
        let mut rng = thread_rng();
        let n = (rng.gen_usize() % 200) + 2;

        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::<BinaryCircuit>::new();
            let bits = builder.evaluator_inputs(&vec![2; n]);
            let conj = builder.and_many(&bits, channel).unwrap();
            builder.output(&conj, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().copied().fold(1, |acc, bit| bit & acc);
            let out = eval_plain(&circuit, &[], &inps).unwrap()[0];
            if out == res {
                continue;
            }
            println!("{:?} {} {}", inps, out, res);
            panic!("incorrect output n={}", n);
        }
    }
    //}}}
    /// OR over a random number of evaluator bits in a binary circuit.
    #[test] // {{{ or_gate_fan_n
    fn or_gate_fan_n() {
        let mut rng = thread_rng();
        let n = (rng.gen_usize() % 200) + 2;
        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::<BinaryCircuit>::new();
            let bits = builder.evaluator_inputs(&vec![2; n]);
            let disj = builder.or_many(&bits, channel).unwrap();
            builder.output(&disj, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().copied().fold(0, |acc, bit| bit | acc);
            let out = eval_plain(&circuit, &[], &inps).unwrap()[0];
            if out == res {
                continue;
            }
            println!("{:?} {} {}", inps, out, res);
            panic!();
        }
    }

    /// OR over a random number of evaluator bits in an arithmetic circuit.
    #[test] // {{{ or_gate_fan_n_arithmetic
    fn or_gate_fan_n_arithmetic() {
        let mut rng = thread_rng();
        let n = (rng.gen_usize() % 200) + 2;

        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::<ArithmeticCircuit>::new();
            let bits = builder.evaluator_inputs(&vec![2; n]);
            let disj = builder.or_many(&bits, channel).unwrap();
            builder.output(&disj, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().copied().fold(0, |acc, bit| bit | acc);
            let out = eval_plain(&circuit, &[], &inps).unwrap()[0];
            if out == res {
                continue;
            }
            println!("{:?} {} {}", inps, out, res);
            panic!();
        }
    }
    //}}}
    /// A single AND gate between a garbler bit and an evaluator bit.
    #[test] // {{{ half_gate
    fn binary_half_gate() {
        let mut rng = thread_rng();
        let q = 2;

        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::<BinaryCircuit>::new();
            let x = builder.garbler_input(q);
            let y = builder.evaluator_input(q);
            let product = builder.and(&x, &y, channel).unwrap();
            builder.output(&product, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();
        for _ in 0..16 {
            let x = rng.gen_u16() % q;
            let y = rng.gen_u16() % q;
            let out = eval_plain(&circuit, &[x], &[y]).unwrap();
            let expected = (x * y) % q;
            assert_eq!(out[0], expected);
        }
    }
    /// A single multiplication gate modulo a random prime.
    #[test] // {{{ half_gate
    fn arithmetic_half_gate() {
        let mut rng = thread_rng();
        let q = rng.gen_prime();

        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::new();
            let x = builder.garbler_input(q);
            let y = builder.evaluator_input(q);
            let product = builder.mul(&x, &y, channel).unwrap();
            builder.output(&product, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();
        for _ in 0..16 {
            let x = rng.gen_u16() % q;
            let y = rng.gen_u16() % q;
            let out = eval_plain(&circuit, &[x], &[y]).unwrap();
            let expected = (x * y) % q;
            assert_eq!(out[0], expected);
        }
    }
    //}}}
    /// Changing modulus from `p` to `q` and back again.
    #[test] // mod_change {{{
    fn mod_change() {
        let mut rng = thread_rng();
        let p = rng.gen_prime();
        let q = rng.gen_prime();

        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::new();
            let x = builder.garbler_input(p);
            let in_q = builder.mod_change(&x, q, channel).unwrap();
            let back_in_p = builder.mod_change(&in_q, p, channel).unwrap();
            builder.output(&back_in_p, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();
        for _ in 0..16 {
            let x = rng.gen_u16() % p;
            let out = eval_plain(&circuit, &[x], &[]).unwrap();
            let expected = x % q;
            assert_eq!(out[0], expected);
        }
    }
    //}}}
    /// Summing many bits after lifting each of them to a larger modulus.
    #[test] // add_many_mod_change {{{
    fn add_many_mod_change() {
        let circuit = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::new();
            let n = 113;
            let args = builder.garbler_inputs(&vec![2; n]);
            let target_mod = n as u16 + 1;
            let mut wires = Vec::with_capacity(args.len());
            for arg in &args {
                wires.push(builder.mod_change(arg, target_mod, channel).unwrap());
            }
            let total = builder.add_many(&wires);
            builder.output(&total, channel).unwrap();
            Ok(builder.finish())
        })
        .unwrap();

        let mut rng = thread_rng();
        for _ in 0..64 {
            let inps = (0..circuit.num_garbler_inputs())
                .map(|i| rng.gen_u16() % circuit.garbler_input_mod(i))
                .collect_vec();
            let s = inps.iter().copied().sum::<u16>();
            println!("{:?}, sum={}", inps, s);
            let out = eval_plain(&circuit, &inps, &[]).unwrap();
            assert_eq!(out[0], s);
        }
    }
    // }}}
    /// Adding a constant to an evaluator input.
    #[test] // constants {{{
    fn constants() {
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let c = rng.gen_u16() % q;

        let circ = Channel::with(std::io::empty(), |channel| {
            let mut builder = CircuitBuilder::new();

            let x = builder.evaluator_input(q);
            let y = builder.constant(c, q, channel).unwrap();
            let sum = builder.add(&x, &y);
            builder.output(&sum, channel).unwrap();

            Ok(builder.finish())
        })
        .unwrap();

        for _ in 0..64 {
            let x = rng.gen_u16() % q;
            let out = eval_plain(&circ, &[], &[x]).unwrap();
            let expected = (x + c) % q;
            assert_eq!(out[0], expected);
        }
    }
    //}}}
}
597
598#[cfg(test)]
599mod bundle {
600    use super::*;
601    use crate::{
602        fancy::{ArithmeticBundleGadgets, BinaryGadgets, BundleGadgets, CrtGadgets},
603        util::{self, RngExt, crt_factor, crt_inv_factor},
604    };
605    use itertools::Itertools;
606    use rand::thread_rng;
607
608    #[test] // bundle input and output {{{
609    fn test_bundle_input_output() {
610        let mut rng = thread_rng();
611        let q = rng.gen_usable_composite_modulus();
612
613        let c = Channel::with(std::io::empty(), |channel| {
614            let mut b = CircuitBuilder::new();
615            let x = b.crt_garbler_input(q);
616            println!("{:?} wires", x.wires().len());
617            b.output_bundle(&x, channel).unwrap();
618            let c: ArithmeticCircuit = b.finish();
619            Ok(c)
620        })
621        .unwrap();
622
623        println!("{:?}", c.output_refs);
624
625        for _ in 0..16 {
626            let x = rng.gen_u128() % q;
627            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
628            println!("{:?}", res);
629            let z = crt_inv_factor(&res, q);
630            assert_eq!(x, z);
631        }
632    }
633
634    //}}}
635    #[test] // bundle addition {{{
636    fn test_addition() {
637        let mut rng = thread_rng();
638        let q = rng.gen_usable_composite_modulus();
639
640        let c = Channel::with(std::io::empty(), |channel| {
641            let mut b = CircuitBuilder::new();
642            let x = b.crt_garbler_input(q);
643            let y = b.crt_evaluator_input(q);
644            let z = b.crt_add(&x, &y);
645            b.output_bundle(&z, channel).unwrap();
646            let c = b.finish();
647            Ok(c)
648        })
649        .unwrap();
650
651        for _ in 0..16 {
652            let x = rng.gen_u128() % q;
653            let y = rng.gen_u128() % q;
654            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
655            let z = crt_inv_factor(&res, q);
656            assert_eq!(z, (x + y) % q);
657        }
658    }
659    //}}}
660    #[test] // bundle subtraction {{{
661    fn test_subtraction() {
662        let mut rng = thread_rng();
663        let q = rng.gen_usable_composite_modulus();
664
665        let c = Channel::with(std::io::empty(), |channel| {
666            let mut b = CircuitBuilder::new();
667            let x = b.crt_garbler_input(q);
668            let y = b.crt_evaluator_input(q);
669            let z = b.sub_bundles(&x, &y);
670            b.output_bundle(&z, channel).unwrap();
671            let c = b.finish();
672            Ok(c)
673        })
674        .unwrap();
675
676        for _ in 0..16 {
677            let x = rng.gen_u128() % q;
678            let y = rng.gen_u128() % q;
679            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
680            let z = crt_inv_factor(&res, q);
681            assert_eq!(z, (x + q - y) % q);
682        }
683    }
684    //}}}
685    #[test] // bundle cmul {{{
686    fn test_cmul() {
687        let mut rng = thread_rng();
688        let q = util::modulus_with_width(16);
689        let y = rng.gen_u128() % q;
690
691        let c = Channel::with(std::io::empty(), |channel| {
692            let mut b = CircuitBuilder::new();
693            let x = b.crt_garbler_input(q);
694            let z = b.crt_cmul(&x, y);
695            b.output_bundle(&z, channel).unwrap();
696            let c = b.finish();
697            Ok(c)
698        })
699        .unwrap();
700
701        for _ in 0..16 {
702            let x = rng.gen_u128() % q;
703            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
704            let z = crt_inv_factor(&res, q);
705            assert_eq!(z, (x * y) % q);
706        }
707    }
708    //}}}
709    #[test] // bundle multiplication {{{
710    fn test_multiplication() {
711        let mut rng = thread_rng();
712        let q = rng.gen_usable_composite_modulus();
713
714        let c = Channel::with(std::io::empty(), |channel| {
715            let mut b = CircuitBuilder::new();
716            let x = b.crt_garbler_input(q);
717            let y = b.crt_evaluator_input(q);
718            let z = b.mul_bundles(&x, &y, channel).unwrap();
719            b.output_bundle(&z, channel).unwrap();
720            let c = b.finish();
721            Ok(c)
722        })
723        .unwrap();
724
725        for _ in 0..16 {
726            let x = rng.gen_u64() as u128 % q;
727            let y = rng.gen_u64() as u128 % q;
728            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
729            let z = crt_inv_factor(&res, q);
730            assert_eq!(z, (x * y) % q);
731        }
732    }
733    // }}}
734    #[test] // bundle cexp {{{
735    fn test_cexp() {
736        let mut rng = thread_rng();
737        let q = util::modulus_with_width(10);
738        let y = rng.gen_u16() % 10;
739
740        let c = Channel::with(std::io::empty(), |channel| {
741            let mut b = CircuitBuilder::new();
742            let x = b.crt_garbler_input(q);
743            let z = b.crt_cexp(&x, y, channel).unwrap();
744            b.output_bundle(&z, channel).unwrap();
745            let c = b.finish();
746            Ok(c)
747        })
748        .unwrap();
749
750        for _ in 0..64 {
751            let x = rng.gen_u16() as u128 % q;
752            let should_be = x.pow(y as u32) % q;
753            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
754            let z = crt_inv_factor(&res, q);
755            assert_eq!(z, should_be);
756        }
757    }
758    // }}}
759    #[test] // bundle remainder {{{
760    fn test_remainder() {
761        let mut rng = thread_rng();
762        let ps = rng.gen_usable_factors();
763        let q = ps.iter().fold(1, |acc, &x| (x as u128) * acc);
764        let p = ps[rng.gen_u16() as usize % ps.len()];
765
766        let c = Channel::with(std::io::empty(), |channel| {
767            let mut b = CircuitBuilder::new();
768            let x = b.crt_garbler_input(q);
769            let z = b.crt_rem(&x, p, channel).unwrap();
770            b.output_bundle(&z, channel).unwrap();
771            let c = b.finish();
772            Ok(c)
773        })
774        .unwrap();
775
776        for _ in 0..64 {
777            let x = rng.gen_u128() % q;
778            let should_be = x % p as u128;
779            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
780            let z = crt_inv_factor(&res, q);
781            assert_eq!(z, should_be);
782        }
783    }
784    //}}}
785    #[test] // bundle equality {{{
786    fn test_equality() {
787        let mut rng = thread_rng();
788        let q = rng.gen_usable_composite_modulus();
789
790        let c = Channel::with(std::io::empty(), |channel| {
791            let mut b = CircuitBuilder::new();
792            let x = b.crt_garbler_input(q);
793            let y = b.crt_evaluator_input(q);
794            let z = b.eq_bundles(&x, &y, channel).unwrap();
795            b.output(&z, channel).unwrap();
796            let c = b.finish();
797            Ok(c)
798        })
799        .unwrap();
800
801        // lets have at least one test where they are surely equal
802        let x = rng.gen_u128() % q;
803        let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(x, q)).unwrap();
804        assert_eq!(res, &[(x == x) as u16]);
805
806        for _ in 0..64 {
807            let x = rng.gen_u128() % q;
808            let y = rng.gen_u128() % q;
809            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
810            assert_eq!(res, &[(x == y) as u16]);
811        }
812    }
813    //}}}
814    #[test] // bundle mixed_radix_addition {{{
815    fn test_mixed_radix_addition() {
816        let mut rng = thread_rng();
817
818        let nargs = 2 + rng.gen_usize() % 100;
819        let mods = (0..7).map(|_| rng.gen_modulus()).collect_vec();
820
821        let circ = Channel::with(std::io::empty(), |channel| {
822            let mut b = CircuitBuilder::new();
823            let xs = (0..nargs)
824                .map(|_| crate::fancy::Bundle::new(b.evaluator_inputs(&mods)))
825                .collect_vec();
826            let z = b.mixed_radix_addition(&xs, channel).unwrap();
827            b.output_bundle(&z, channel).unwrap();
828            let circ = b.finish();
829            Ok(circ)
830        })
831        .unwrap();
832
833        let Q: u128 = mods.iter().map(|&q| q as u128).product();
834
835        // test maximum overflow
836        let mut ds = Vec::new();
837        for _ in 0..nargs {
838            ds.extend(util::as_mixed_radix(Q - 1, &mods).iter());
839        }
840        let res = eval_plain(&circ, &[], &ds).unwrap();
841        assert_eq!(
842            util::from_mixed_radix(&res, &mods),
843            (Q - 1) * (nargs as u128) % Q
844        );
845
846        // test random values
847        for _ in 0..4 {
848            let mut should_be = 0;
849            let mut ds = Vec::new();
850            for _ in 0..nargs {
851                let x = rng.gen_u128() % Q;
852                should_be = (should_be + x) % Q;
853                ds.extend(util::as_mixed_radix(x, &mods).iter());
854            }
855            let res = eval_plain(&circ, &[], &ds).unwrap();
856            assert_eq!(util::from_mixed_radix(&res, &mods), should_be);
857        }
858    }
859    //}}}
860    #[test] // bundle relu {{{
861    fn test_relu() {
862        let mut rng = thread_rng();
863        let q = util::modulus_with_width(10);
864        println!("q={}", q);
865
866        let c = Channel::with(std::io::empty(), |channel| {
867            let mut b = CircuitBuilder::new();
868            let x = b.crt_garbler_input(q);
869            let z = b.crt_relu(&x, "100%", None, channel).unwrap();
870            b.output_bundle(&z, channel).unwrap();
871            let c = b.finish();
872            Ok(c)
873        })
874        .unwrap();
875
876        for _ in 0..128 {
877            let pt = rng.gen_u128() % q;
878            let should_be = if pt < q / 2 { pt } else { 0 };
879            let res = eval_plain(&c, &crt_factor(pt, q), &[]).unwrap();
880            let z = crt_inv_factor(&res, q);
881            assert_eq!(z, should_be);
882        }
883    }
884    //}}}
885    #[test] // bundle sgn {{{
886    fn test_sgn() {
887        let mut rng = thread_rng();
888        let q = util::modulus_with_width(10);
889        println!("q={}", q);
890
891        let c = Channel::with(std::io::empty(), |channel| {
892            let mut b = CircuitBuilder::new();
893            let x = b.crt_garbler_input(q);
894            let z = b.crt_sgn(&x, "100%", None, channel).unwrap();
895            b.output_bundle(&z, channel).unwrap();
896            let c = b.finish();
897            Ok(c)
898        })
899        .unwrap();
900
901        for _ in 0..128 {
902            let pt = rng.gen_u128() % q;
903            let should_be = if pt < q / 2 { 1 } else { q - 1 };
904            let res = eval_plain(&c, &crt_factor(pt, q), &[]).unwrap();
905            let z = crt_inv_factor(&res, q);
906            assert_eq!(z, should_be);
907        }
908    }
909    //}}}
910    #[test] // bundle leq {{{
911    fn test_leq() {
912        let mut rng = thread_rng();
913        let q = util::modulus_with_width(10);
914
915        let c = Channel::with(std::io::empty(), |channel| {
916            let mut b = CircuitBuilder::new();
917            let x = b.crt_garbler_input(q);
918            let y = b.crt_evaluator_input(q);
919            let z = b.crt_lt(&x, &y, "100%", channel).unwrap();
920            b.output(&z, channel).unwrap();
921            let c = b.finish();
922            Ok(c)
923        })
924        .unwrap();
925
926        // lets have at least one test where they are surely equal
927        let x = rng.gen_u128() % q / 2;
928        let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(x, q)).unwrap();
929        assert_eq!(res, &[(x < x) as u16], "x={}", x);
930
931        for _ in 0..64 {
932            let x = rng.gen_u128() % q / 2;
933            let y = rng.gen_u128() % q / 2;
934            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
935            assert_eq!(res, &[(x < y) as u16], "x={} y={}", x, y);
936        }
937    }
938    //}}}
939    #[test] // bundle max {{{
940    fn test_max() {
941        let mut rng = thread_rng();
942        let q = util::modulus_with_width(10);
943        let n = 10;
944        println!("n={} q={}", n, q);
945
946        let c = Channel::with(std::io::empty(), |channel| {
947            let mut b = CircuitBuilder::new();
948            let xs = (0..n).map(|_| b.crt_garbler_input(q)).collect_vec();
949            let z = b.crt_max(&xs, "100%", channel).unwrap();
950            b.output_bundle(&z, channel).unwrap();
951            let c = b.finish();
952            Ok(c)
953        })
954        .unwrap();
955
956        for _ in 0..16 {
957            let inps = (0..n).map(|_| rng.gen_u128() % (q / 2)).collect_vec();
958            println!("{:?}", inps);
959            let should_be = *inps.iter().max().unwrap();
960
961            let enc_inps = inps
962                .into_iter()
963                .flat_map(|x| crt_factor(x, q))
964                .collect_vec();
965            let res = eval_plain(&c, &enc_inps, &[]).unwrap();
966            let z = crt_inv_factor(&res, q);
967            assert_eq!(z, should_be);
968        }
969    }
970    //}}}
971    #[test] // binary addition {{{
972    fn test_binary_addition() {
973        let mut rng = thread_rng();
974        let n = 2 + (rng.gen_usize() % 10);
975        let q = 2;
976        let Q = util::product(&vec![q; n]);
977        println!("n={} q={} Q={}", n, q, Q);
978
979        let c = Channel::with(std::io::empty(), |channel| {
980            let mut b = CircuitBuilder::<BinaryCircuit>::new();
981            let x = b.bin_garbler_input(n);
982            let y = b.bin_evaluator_input(n);
983            let (zs, carry) = b.bin_addition(&x, &y, channel).unwrap();
984            b.output(&carry, channel).unwrap();
985            b.output_bundle(&zs, channel).unwrap();
986            let c = b.finish();
987            Ok(c)
988        })
989        .unwrap();
990
991        for _ in 0..16 {
992            let x = rng.gen_u128() % Q;
993            let y = rng.gen_u128() % Q;
994            println!("x={} y={}", x, y);
995            let res_should_be = (x + y) % Q;
996            let carry_should_be = (x + y >= Q) as u16;
997            let res = eval_plain(&c, &util::u128_to_bits(x, n), &util::u128_to_bits(y, n)).unwrap();
998            assert_eq!(util::u128_from_bits(&res[1..]), res_should_be);
999            assert_eq!(res[0], carry_should_be);
1000        }
1001    }
1002    //}}}
1003    #[test] // binary demux {{{
1004    fn test_bin_demux() {
1005        let mut rng = thread_rng();
1006        let nbits = 1 + (rng.gen_usize() % 7);
1007        let Q = 1 << nbits as u128;
1008
1009        let c = Channel::with(std::io::empty(), |channel| {
1010            let mut b = CircuitBuilder::<BinaryCircuit>::new();
1011            let x = b.bin_garbler_input(nbits);
1012            let d = b.bin_demux(&x, channel).unwrap();
1013            b.outputs(&d, channel).unwrap();
1014            let c = b.finish();
1015            Ok(c)
1016        })
1017        .unwrap();
1018
1019        for _ in 0..16 {
1020            let x = rng.gen_u128() % Q;
1021            println!("x={}", x);
1022            let mut should_be = vec![0; Q as usize];
1023            should_be[x as usize] = 1;
1024
1025            let res = eval_plain(&c, &util::u128_to_bits(x, nbits), &[]).unwrap();
1026
1027            for (i, y) in res.into_iter().enumerate() {
1028                if i as u128 == x {
1029                    assert_eq!(y, 1);
1030                } else {
1031                    assert_eq!(y, 0);
1032                }
1033            }
1034        }
1035    }
1036    //}}}
1037}