1use crate::{
5 dummy::{Dummy, DummyVal},
6 fancy::{BinaryBundle, CrtBundle, Fancy, FancyInput, HasModulus},
7 informer::Informer,
8};
9use itertools::Itertools;
10use std::{collections::HashMap, fmt::Display};
11use swanky_channel::Channel;
12
13mod binary;
14pub use binary::{BinaryCircuit, BinaryGate};
15mod arithmetic;
16pub use arithmetic::{ArithmeticCircuit, ArithmeticGate};
17
/// A lightweight, copyable handle to a wire of a circuit.
///
/// The handle stores the index assigned to the wire when its gate was added
/// and the modulus of the wire. It derives `Eq` and `Hash` in addition to
/// `PartialEq`, so handles can be used as keys in sets and maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CircuitRef {
    /// Index of the wire this handle refers to.
    pub(crate) ix: usize,
    /// Modulus of the wire this handle refers to.
    pub(crate) modulus: u16,
}
25
26impl std::fmt::Display for CircuitRef {
27 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
28 write!(f, "[{} | {}]", self.ix, self.modulus)
29 }
30}
31
32impl HasModulus for CircuitRef {
33 fn modulus(&self) -> u16 {
34 self.modulus
35 }
36}
37
/// Reporting of summary information about a circuit.
pub trait CircuitInfo {
    /// Print information about the circuit.
    ///
    /// # Errors
    ///
    /// Returns an error if gathering the information fails.
    fn print_info(&self) -> swanky_error::Result<()>;
}
46
47impl<C: EvaluableCircuit<Informer<Dummy>>> CircuitInfo for C {
48 fn print_info(&self) -> swanky_error::Result<()> {
49 let mut informer = crate::informer::Informer::new(Dummy::new());
50
51 let gb = Channel::with(std::io::empty(), |channel| {
53 self.get_garbler_input_refs()
54 .iter()
55 .map(|r| informer.encode(0, r.modulus(), channel))
56 .collect::<swanky_error::Result<Vec<DummyVal>>>()
57 })?;
58 let ev = Channel::with(std::io::empty(), |channel| {
59 self.get_evaluator_input_refs()
60 .iter()
61 .map(|r| informer.encode(0, r.modulus(), channel))
62 .collect::<swanky_error::Result<Vec<DummyVal>>>()
63 })?;
64
65 Channel::with(std::io::empty(), |c| self.eval(&mut informer, &gb, &ev, c))?;
66 println!("{}", informer.stats());
67 Ok(())
68 }
69}
70
71pub trait EvaluableCircuit<F: Fancy>: CircuitType {
75 fn eval(
83 &self,
84 f: &mut F,
85 garbler_inputs: &[F::Item],
86 evaluator_inputs: &[F::Item],
87 channel: &mut Channel,
88 ) -> swanky_error::Result<Option<Vec<u16>>> {
89 let wirelabels = self.eval_to_wirelabels(f, garbler_inputs, evaluator_inputs, channel)?;
90 f.outputs(&wirelabels, channel)
91 }
92
93 fn eval_to_wirelabels(
101 &self,
102 f: &mut F,
103 garbler_inputs: &[F::Item],
104 evaluator_inputs: &[F::Item],
105 channel: &mut Channel,
106 ) -> swanky_error::Result<Vec<F::Item>>;
107}
108
/// Constructors that every gate representation must provide so that a
/// circuit builder can create constant and input gates generically.
///
/// Gate types must also implement `Display`.
pub trait GateType: Display {
    /// Builds a constant gate carrying the value `val`.
    fn make_constant(val: u16) -> Self;

    /// Builds a garbler-input gate with the input id `id`.
    fn make_garbler_input(id: usize) -> Self;

    /// Builds an evaluator-input gate with the input id `id`.
    fn make_evaluator_input(id: usize) -> Self;
}
120
121impl GateType for ArithmeticGate {
122 fn make_constant(val: u16) -> Self {
123 Self::Constant { val }
124 }
125
126 fn make_garbler_input(id: usize) -> Self {
127 Self::GarblerInput { id }
128 }
129
130 fn make_evaluator_input(id: usize) -> Self {
131 Self::EvaluatorInput { id }
132 }
133}
134
135pub trait CircuitType {
137 type Gate: GateType;
139
140 fn increment_nonfree_gates(&mut self);
142
143 fn new(ngates: Option<usize>) -> Self;
145
146 fn get_output_refs(&self) -> &[CircuitRef];
148
149 fn get_garbler_input_refs(&self) -> &[CircuitRef];
151
152 fn get_evaluator_input_refs(&self) -> &[CircuitRef];
154
155 fn get_num_nonfree_gates(&self) -> usize;
157
158 fn push_gates(&mut self, gate: Self::Gate);
160
161 fn push_const_ref(&mut self, xref: CircuitRef);
163
164 fn push_output_ref(&mut self, xref: CircuitRef);
166
167 fn push_garbler_input_ref(&mut self, xref: CircuitRef);
169
170 fn push_evaluator_input_ref(&mut self, xref: CircuitRef);
172
173 fn push_modulus(&mut self, modulus: u16);
175
176 fn garbler_input_mod(&self, i: usize) -> u16;
178
179 fn evaluator_input_mod(&self, i: usize) -> u16;
181
182 #[inline]
184 fn num_garbler_inputs(&self) -> usize {
185 self.get_garbler_input_refs().len()
186 }
187
188 #[inline]
190 fn num_evaluator_inputs(&self) -> usize {
191 self.get_evaluator_input_refs().len()
192 }
193
194 #[inline]
196 fn noutputs(&self) -> usize {
197 self.get_output_refs().len()
198 }
199}
200
201pub fn eval_plain<C: EvaluableCircuit<Dummy>>(
207 circuit: &C,
208 garbler_inputs: &[u16],
209 evaluator_inputs: &[u16],
210) -> swanky_error::Result<Vec<u16>> {
211 assert_eq!(garbler_inputs.len(), circuit.num_garbler_inputs());
212 assert_eq!(evaluator_inputs.len(), circuit.num_evaluator_inputs());
213
214 let mut dummy = crate::dummy::Dummy::new();
215
216 let gb = garbler_inputs
218 .iter()
219 .zip(circuit.get_garbler_input_refs().iter())
220 .map(|(x, r)| DummyVal::new(*x, r.modulus()))
221 .collect_vec();
222 let ev = evaluator_inputs
223 .iter()
224 .zip(circuit.get_evaluator_input_refs().iter())
225 .map(|(x, r)| DummyVal::new(*x, r.modulus()))
226 .collect_vec();
227
228 let outputs = Channel::with(std::io::empty(), |c| {
230 Ok(circuit.eval(&mut dummy, &gb, &ev, c).unwrap())
232 })
233 .unwrap();
234 Ok(outputs.expect("dummy will always return Some(u16) output"))
235}
236
/// Incrementally assembles a circuit of type `Circuit`.
///
/// The builder hands out [`CircuitRef`]s for the gates it adds and keeps the
/// circuit under construction until it is released by `finish`.
pub struct CircuitBuilder<Circuit> {
    /// Index that will be assigned to the next gate's [`CircuitRef`].
    next_ref_ix: usize,
    /// Id that will be assigned to the next garbler input.
    next_garbler_input_id: usize,
    /// Id that will be assigned to the next evaluator input.
    next_evaluator_input_id: usize,
    /// Constants created so far, keyed by `(value, modulus)`, so that each
    /// distinct constant is added to the circuit only once.
    const_map: HashMap<(u16, u16), CircuitRef>,
    /// The circuit being built.
    circ: Circuit,
}
245
246impl<Circuit: CircuitType> Fancy for CircuitBuilder<Circuit> {
247 type Item = CircuitRef;
248
249 fn constant(
250 &mut self,
251 val: u16,
252 modulus: u16,
253 _: &mut Channel,
254 ) -> swanky_error::Result<CircuitRef> {
255 Ok(self.lookup_constant(val, modulus))
256 }
257
258 fn output(&mut self, xref: &CircuitRef, _: &mut Channel) -> swanky_error::Result<Option<u16>> {
259 self.circ.push_output_ref(*xref);
260 Ok(None)
261 }
262}
263
264impl<Circuit: CircuitType> CircuitBuilder<Circuit> {
265 pub fn new() -> Self {
267 CircuitBuilder {
268 next_ref_ix: 0,
269 next_garbler_input_id: 0,
270 next_evaluator_input_id: 0,
271 const_map: HashMap::new(),
272 circ: Circuit::new(None),
273 }
274 }
275
276 pub fn finish(self) -> Circuit {
278 self.circ
279 }
280
281 fn lookup_constant(&mut self, val: u16, modulus: u16) -> CircuitRef {
284 match self.const_map.get(&(val, modulus)) {
285 Some(&r) => r,
286 None => {
287 let gate = Circuit::Gate::make_constant(val);
288 let r = self.gate(gate, modulus);
289 self.const_map.insert((val, modulus), r);
290 self.circ.push_const_ref(r);
291 r
292 }
293 }
294 }
295
296 fn get_next_garbler_input_id(&mut self) -> usize {
297 let current = self.next_garbler_input_id;
298 self.next_garbler_input_id += 1;
299 current
300 }
301
302 fn get_next_evaluator_input_id(&mut self) -> usize {
303 let current = self.next_evaluator_input_id;
304 self.next_evaluator_input_id += 1;
305 current
306 }
307
308 fn get_next_ciphertext_id(&mut self) -> usize {
309 let current = self.circ.get_num_nonfree_gates();
310 self.circ.increment_nonfree_gates();
311 current
312 }
313
314 fn get_next_ref_ix(&mut self) -> usize {
315 let current = self.next_ref_ix;
316 self.next_ref_ix += 1;
317 current
318 }
319
320 fn gate(&mut self, gate: Circuit::Gate, modulus: u16) -> CircuitRef {
321 self.circ.push_gates(gate);
322 self.circ.push_modulus(modulus);
323 let ix = self.get_next_ref_ix();
324 CircuitRef { ix, modulus }
325 }
326
327 pub fn garbler_input(&mut self, modulus: u16) -> CircuitRef {
329 let id = self.get_next_garbler_input_id();
330 let r = self.gate(Circuit::Gate::make_garbler_input(id), modulus);
331 self.circ.push_garbler_input_ref(r);
332 r
333 }
334
335 pub fn evaluator_input(&mut self, modulus: u16) -> CircuitRef {
337 let id = self.get_next_evaluator_input_id();
338 let r = self.gate(Circuit::Gate::make_evaluator_input(id), modulus);
339 self.circ.push_evaluator_input_ref(r);
340 r
341 }
342
343 pub fn garbler_inputs(&mut self, mods: &[u16]) -> Vec<CircuitRef> {
345 mods.iter().map(|q| self.garbler_input(*q)).collect()
346 }
347
348 pub fn evaluator_inputs(&mut self, mods: &[u16]) -> Vec<CircuitRef> {
350 mods.iter().map(|q| self.evaluator_input(*q)).collect()
351 }
352
353 pub fn crt_garbler_input(&mut self, modulus: u128) -> CrtBundle<CircuitRef> {
355 CrtBundle::new(self.garbler_inputs(&crate::util::factor(modulus)))
356 }
357
358 pub fn crt_evaluator_input(&mut self, modulus: u128) -> CrtBundle<CircuitRef> {
360 CrtBundle::new(self.evaluator_inputs(&crate::util::factor(modulus)))
361 }
362
363 pub fn bin_garbler_input(&mut self, nbits: usize) -> BinaryBundle<CircuitRef> {
365 BinaryBundle::new(self.garbler_inputs(&vec![2; nbits]))
366 }
367
368 pub fn bin_evaluator_input(&mut self, nbits: usize) -> BinaryBundle<CircuitRef> {
370 BinaryBundle::new(self.evaluator_inputs(&vec![2; nbits]))
371 }
372}
373
374impl<Circuit: CircuitType> Default for CircuitBuilder<Circuit> {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
/// Plaintext evaluation tests for circuits built gate by gate.
#[cfg(test)]
mod plaintext {
    use super::*;
    use crate::{FancyArithmetic, FancyBinary, util::RngExt};
    use itertools::Itertools;
    use rand::thread_rng;

    /// AND over a random number of evaluator bits equals the bitwise AND of
    /// all inputs.
    #[test]
    fn and_gate_fan_n() {
        let mut rng = thread_rng();
        let n = 2 + (rng.gen_usize() % 200);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::<BinaryCircuit>::new();
            let inps = b.evaluator_inputs(&vec![2; n]);
            let z = b.and_many(&inps, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().fold(1, |acc, &x| x & acc);
            let out = eval_plain(&c, &[], &inps).unwrap()[0];
            assert_eq!(out, res, "incorrect output n={} inputs={:?}", n, inps);
        }
    }

    /// OR over a random number of evaluator bits (binary circuit) equals the
    /// bitwise OR of all inputs.
    #[test]
    fn or_gate_fan_n() {
        let mut rng = thread_rng();
        let n = 2 + (rng.gen_usize() % 200);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b: CircuitBuilder<BinaryCircuit> = CircuitBuilder::new();
            let inps = b.evaluator_inputs(&vec![2; n]);
            let z = b.or_many(&inps, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().fold(0, |acc, &x| x | acc);
            let out = eval_plain(&c, &[], &inps).unwrap()[0];
            assert_eq!(out, res, "incorrect output n={} inputs={:?}", n, inps);
        }
    }

    /// Same as `or_gate_fan_n`, but built as an arithmetic circuit.
    #[test]
    fn or_gate_fan_n_arithmetic() {
        let mut rng = thread_rng();
        let n = 2 + (rng.gen_usize() % 200);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b: CircuitBuilder<ArithmeticCircuit> = CircuitBuilder::new();
            let inps = b.evaluator_inputs(&vec![2; n]);
            let z = b.or_many(&inps, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps: Vec<u16> = (0..n).map(|_| rng.gen_bool() as u16).collect();
            let res = inps.iter().fold(0, |acc, &x| x | acc);
            let out = eval_plain(&c, &[], &inps).unwrap()[0];
            assert_eq!(out, res, "incorrect output n={} inputs={:?}", n, inps);
        }
    }

    /// A single AND of a garbler bit and an evaluator bit.
    #[test]
    fn binary_half_gate() {
        let mut rng = thread_rng();
        let q = 2;

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::<BinaryCircuit>::new();
            let x = b.garbler_input(q);
            let y = b.evaluator_input(q);
            let z = b.and(&x, &y, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u16() % q;
            let y = rng.gen_u16() % q;
            let out = eval_plain(&c, &[x], &[y]).unwrap();
            assert_eq!(out[0], x * y % q);
        }
    }

    /// A single multiplication modulo a random prime.
    #[test]
    fn arithmetic_half_gate() {
        let mut rng = thread_rng();
        let q = rng.gen_prime();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.garbler_input(q);
            let y = b.evaluator_input(q);
            let z = b.mul(&x, &y, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u16() % q;
            let y = rng.gen_u16() % q;
            let out = eval_plain(&c, &[x], &[y]).unwrap();
            assert_eq!(out[0], x * y % q);
        }
    }

    /// Changing the modulus from `p` to `q` and back to `p`.
    #[test]
    fn mod_change() {
        let mut rng = thread_rng();
        let p = rng.gen_prime();
        let q = rng.gen_prime();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.garbler_input(p);
            let y = b.mod_change(&x, q, channel).unwrap();
            let z = b.mod_change(&y, p, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u16() % p;
            let out = eval_plain(&c, &[x], &[]).unwrap();
            assert_eq!(out[0], x % q);
        }
    }

    /// Sums 113 bits after lifting each of them to modulus 114, so the sum
    /// cannot wrap around.
    #[test]
    fn add_many_mod_change() {
        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let n = 113;
            let args = b.garbler_inputs(&vec![2; n]);
            let wires = args
                .iter()
                .map(|x| b.mod_change(x, n as u16 + 1, channel).unwrap())
                .collect_vec();
            let s = b.add_many(&wires);
            b.output(&s, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        let mut rng = thread_rng();
        for _ in 0..64 {
            let inps = (0..c.num_garbler_inputs())
                .map(|i| rng.gen_u16() % c.garbler_input_mod(i))
                .collect_vec();
            let s: u16 = inps.iter().sum();
            println!("{:?}, sum={}", inps, s);
            let out = eval_plain(&c, &inps, &[]).unwrap();
            assert_eq!(out[0], s);
        }
    }

    /// Adding a constant to an evaluator input.
    #[test]
    fn constants() {
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let c = rng.gen_u16() % q;

        let circ = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();

            let x = b.evaluator_input(q);
            let y = b.constant(c, q, channel).unwrap();
            let z = b.add(&x, &y);
            b.output(&z, channel).unwrap();

            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..64 {
            let x = rng.gen_u16() % q;
            let z = eval_plain(&circ, &[], &[x]).unwrap();
            assert_eq!(z[0], (x + c) % q);
        }
    }
}
597
/// Plaintext evaluation tests for circuits built from bundle gadgets.
#[cfg(test)]
mod bundle {
    use super::*;
    use crate::{
        fancy::{ArithmeticBundleGadgets, BinaryGadgets, BundleGadgets, CrtGadgets},
        util::{self, RngExt, crt_factor, crt_inv_factor},
    };
    use itertools::Itertools;
    use rand::thread_rng;

    /// A CRT input routed straight to the outputs round-trips its value.
    #[test]
    fn test_bundle_input_output() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            println!("{:?} wires", x.wires().len());
            b.output_bundle(&x, channel).unwrap();
            let c: ArithmeticCircuit = b.finish();
            Ok(c)
        })
        .unwrap();

        println!("{:?}", c.output_refs);

        for _ in 0..16 {
            let x = rng.gen_u128() % q;
            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
            println!("{:?}", res);
            let z = crt_inv_factor(&res, q);
            assert_eq!(x, z);
        }
    }

    /// CRT addition of a garbler and an evaluator bundle.
    #[test]
    fn test_addition() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let y = b.crt_evaluator_input(q);
            let z = b.crt_add(&x, &y);
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, (x + y) % q);
        }
    }

    /// Bundle subtraction of an evaluator bundle from a garbler bundle.
    #[test]
    fn test_subtraction() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let y = b.crt_evaluator_input(q);
            let z = b.sub_bundles(&x, &y);
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, (x + q - y) % q);
        }
    }

    /// Multiplication of a CRT bundle by a constant chosen at build time.
    #[test]
    fn test_cmul() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(16);
        let y = rng.gen_u128() % q;

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let z = b.crt_cmul(&x, y);
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u128() % q;
            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, (x * y) % q);
        }
    }

    /// Multiplication of a garbler bundle with an evaluator bundle.
    #[test]
    fn test_multiplication() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let y = b.crt_evaluator_input(q);
            let z = b.mul_bundles(&x, &y, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            // Operands are drawn from u64 so that `x * y` fits in a u128.
            let x = rng.gen_u64() as u128 % q;
            let y = rng.gen_u64() as u128 % q;
            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, (x * y) % q);
        }
    }

    /// Raising a CRT bundle to a constant power below 10.
    #[test]
    fn test_cexp() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        let y = rng.gen_u16() % 10;

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let z = b.crt_cexp(&x, y, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..64 {
            let x = rng.gen_u16() as u128 % q;
            let should_be = x.pow(y as u32) % q;
            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, should_be);
        }
    }

    /// Remainder of a CRT bundle modulo one of its own factors.
    #[test]
    fn test_remainder() {
        let mut rng = thread_rng();
        let ps = rng.gen_usable_factors();
        let q = ps.iter().fold(1, |acc, &x| (x as u128) * acc);
        let p = ps[rng.gen_u16() as usize % ps.len()];

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let z = b.crt_rem(&x, p, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..64 {
            let x = rng.gen_u128() % q;
            let should_be = x % p as u128;
            let res = eval_plain(&c, &crt_factor(x, q), &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, should_be);
        }
    }

    /// Equality of two CRT bundles, first for equal inputs and then for
    /// random pairs.
    #[test]
    fn test_equality() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let y = b.crt_evaluator_input(q);
            let z = b.eq_bundles(&x, &y, channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        // Equal inputs must always compare as equal.
        let x = rng.gen_u128() % q;
        let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(x, q)).unwrap();
        assert_eq!(res, &[1u16], "x={}", x);

        for _ in 0..64 {
            let x = rng.gen_u128() % q;
            let y = rng.gen_u128() % q;
            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
            assert_eq!(res, &[(x == y) as u16]);
        }
    }

    /// Mixed-radix addition of many bundles, first with every operand at its
    /// maximum value and then with random operands.
    #[test]
    fn test_mixed_radix_addition() {
        let mut rng = thread_rng();

        let nargs = 2 + rng.gen_usize() % 100;
        let mods = (0..7).map(|_| rng.gen_modulus()).collect_vec();

        let circ = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let xs = (0..nargs)
                .map(|_| crate::fancy::Bundle::new(b.evaluator_inputs(&mods)))
                .collect_vec();
            let z = b.mixed_radix_addition(&xs, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        let total_modulus: u128 = mods.iter().map(|&q| q as u128).product();

        // Worst case: every operand is the largest representable value.
        let mut ds = Vec::new();
        for _ in 0..nargs {
            ds.extend(util::as_mixed_radix(total_modulus - 1, &mods).iter());
        }
        let res = eval_plain(&circ, &[], &ds).unwrap();
        assert_eq!(
            util::from_mixed_radix(&res, &mods),
            (total_modulus - 1) * (nargs as u128) % total_modulus
        );

        // Random operands.
        for _ in 0..4 {
            let mut should_be = 0;
            let mut ds = Vec::new();
            for _ in 0..nargs {
                let x = rng.gen_u128() % total_modulus;
                should_be = (should_be + x) % total_modulus;
                ds.extend(util::as_mixed_radix(x, &mods).iter());
            }
            let res = eval_plain(&circ, &[], &ds).unwrap();
            assert_eq!(util::from_mixed_radix(&res, &mods), should_be);
        }
    }

    /// ReLU on a CRT bundle: values below `q / 2` pass through, the rest map
    /// to zero.
    #[test]
    fn test_relu() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        println!("q={}", q);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let z = b.crt_relu(&x, "100%", None, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..128 {
            let pt = rng.gen_u128() % q;
            let should_be = if pt < q / 2 { pt } else { 0 };
            let res = eval_plain(&c, &crt_factor(pt, q), &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, should_be);
        }
    }

    /// Sign on a CRT bundle: `1` for values below `q / 2`, `q - 1` otherwise.
    #[test]
    fn test_sgn() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        println!("q={}", q);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let z = b.crt_sgn(&x, "100%", None, channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..128 {
            let pt = rng.gen_u128() % q;
            let should_be = if pt < q / 2 { 1 } else { q - 1 };
            let res = eval_plain(&c, &crt_factor(pt, q), &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, should_be);
        }
    }

    /// Strict less-than (`crt_lt`) on CRT bundles, first for equal inputs and
    /// then for random pairs.
    #[test]
    fn test_leq() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let x = b.crt_garbler_input(q);
            let y = b.crt_evaluator_input(q);
            let z = b.crt_lt(&x, &y, "100%", channel).unwrap();
            b.output(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        // A value is never strictly less than itself.
        let x = rng.gen_u128() % q / 2;
        let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(x, q)).unwrap();
        assert_eq!(res, &[0u16], "x={}", x);

        for _ in 0..64 {
            let x = rng.gen_u128() % q / 2;
            let y = rng.gen_u128() % q / 2;
            let res = eval_plain(&c, &crt_factor(x, q), &crt_factor(y, q)).unwrap();
            assert_eq!(res, &[(x < y) as u16], "x={} y={}", x, y);
        }
    }

    /// Maximum of ten CRT bundles whose values lie below `q / 2`.
    #[test]
    fn test_max() {
        let mut rng = thread_rng();
        let q = util::modulus_with_width(10);
        let n = 10;
        println!("n={} q={}", n, q);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::new();
            let xs = (0..n).map(|_| b.crt_garbler_input(q)).collect_vec();
            let z = b.crt_max(&xs, "100%", channel).unwrap();
            b.output_bundle(&z, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let inps = (0..n).map(|_| rng.gen_u128() % (q / 2)).collect_vec();
            println!("{:?}", inps);
            let should_be = *inps.iter().max().unwrap();

            let enc_inps = inps
                .into_iter()
                .flat_map(|x| crt_factor(x, q))
                .collect_vec();
            let res = eval_plain(&c, &enc_inps, &[]).unwrap();
            let z = crt_inv_factor(&res, q);
            assert_eq!(z, should_be);
        }
    }

    /// Binary addition with carry; the carry is output first, followed by the
    /// sum bits.
    #[test]
    fn test_binary_addition() {
        let mut rng = thread_rng();
        let n = 2 + (rng.gen_usize() % 10);
        let q = 2;
        let range = util::product(&vec![q; n]);
        println!("n={} q={} Q={}", n, q, range);

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::<BinaryCircuit>::new();
            let x = b.bin_garbler_input(n);
            let y = b.bin_evaluator_input(n);
            let (zs, carry) = b.bin_addition(&x, &y, channel).unwrap();
            b.output(&carry, channel).unwrap();
            b.output_bundle(&zs, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u128() % range;
            let y = rng.gen_u128() % range;
            println!("x={} y={}", x, y);
            let res_should_be = (x + y) % range;
            let carry_should_be = (x + y >= range) as u16;
            let res = eval_plain(&c, &util::u128_to_bits(x, n), &util::u128_to_bits(y, n)).unwrap();
            assert_eq!(util::u128_from_bits(&res[1..]), res_should_be);
            assert_eq!(res[0], carry_should_be);
        }
    }

    /// Demultiplexing an `nbits`-bit value yields a one-hot vector of length
    /// `2^nbits` with the `1` at position `x`.
    #[test]
    fn test_bin_demux() {
        let mut rng = thread_rng();
        let nbits = 1 + (rng.gen_usize() % 7);
        let range = 1u128 << nbits;

        let c = Channel::with(std::io::empty(), |channel| {
            let mut b = CircuitBuilder::<BinaryCircuit>::new();
            let x = b.bin_garbler_input(nbits);
            let d = b.bin_demux(&x, channel).unwrap();
            b.outputs(&d, channel).unwrap();
            Ok(b.finish())
        })
        .unwrap();

        for _ in 0..16 {
            let x = rng.gen_u128() % range;
            println!("x={}", x);
            let mut should_be = vec![0u16; range as usize];
            should_be[x as usize] = 1;

            let res = eval_plain(&c, &util::u128_to_bits(x, nbits), &[]).unwrap();

            // Comparing whole vectors also checks the number of outputs.
            assert_eq!(res, should_be, "x={}", x);
        }
    }
}