fancy_garbling/twopac/semihonest/
mod.rs1mod evaluator;
4mod garbler;
5
6pub use evaluator::Evaluator;
7pub use garbler::Garbler;
8
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AllWire, CrtBundle, CrtGadgets, FancyArithmetic, FancyBinary, FancyInput, WireLabel,
        WireMod2,
        circuit::{BinaryCircuit, CircuitInfo, EvaluableCircuit, eval_plain},
        dummy::Dummy,
        util::RngExt,
    };
    use itertools::Itertools;
    use swanky_aes_rng::AesRng;
    use swanky_channel_legacy::{UnixChannel, unix_channel_pair};
    use swanky_ot_chou_orlandi::{Receiver as ChouOrlandiReceiver, Sender as ChouOrlandiSender};

    /// Garbler role as instantiated by these tests: one end of a
    /// `unix_channel_pair`, an `AesRng`, and the Chou–Orlandi OT sender.
    type GB<Wire> = Garbler<UnixChannel, AesRng, ChouOrlandiSender, Wire>;

    /// Evaluator role matching `GB`: the other channel end, an `AesRng`, and
    /// the Chou–Orlandi OT receiver.
    type EV<Wire> = Evaluator<UnixChannel, AesRng, ChouOrlandiReceiver, Wire>;

    /// Adds `a` and `b` with `f.add` and reveals the sum via `f.output`.
    ///
    /// Returns whatever `output` yields for the calling party; the evaluator
    /// side of `test_addition_circuit` expects `Some(sum)`.
    fn addition<F: FancyArithmetic>(
        f: &mut F,
        a: &F::Item,
        b: &F::Item,
    ) -> Result<Option<u16>, F::Error> {
        let c = f.add(a, b)?;
        f.output(&c)
    }

    #[test]
    fn test_addition_circuit() {
        // Both inputs are encoded mod 3, so exercise every residue pair.
        const MODULUS: u16 = 3;
        for a in 0..MODULUS {
            for b in 0..MODULUS {
                let (sender, receiver) = unix_channel_pair();
                let garbler = std::thread::spawn(move || {
                    let mut gb = GB::<AllWire>::new(sender, AesRng::new()).unwrap();
                    let x = gb.encode(a, MODULUS).unwrap();
                    let ys = gb.receive_many(&[MODULUS]).unwrap();
                    addition(&mut gb, &x, &ys[0]).unwrap();
                });

                let mut ev = EV::<AllWire>::new(receiver, AesRng::new()).unwrap();
                let x = ev.receive(MODULUS).unwrap();
                let ys = ev.encode_many(&[b], &[MODULUS]).unwrap();
                let output = addition(&mut ev, &x, &ys[0]).unwrap().unwrap();
                // Join so that a panic on the garbler side fails this test
                // instead of being silently dropped with a detached thread.
                garbler.join().unwrap();

                assert_eq!((a + b) % MODULUS, output);
            }
        }
    }

    /// For each CRT bundle in `xs`: multiplies it by the constant 1, applies
    /// `crt_relu(.., "100%", None)`, and outputs the result.
    ///
    /// Returns `Some(values)` only if `crt_output` yielded a value for every
    /// bundle, and `None` otherwise.
    fn relu<F: FancyArithmetic + FancyBinary>(
        b: &mut F,
        xs: &[CrtBundle<F::Item>],
    ) -> Option<Vec<u128>> {
        // Evaluate every bundle eagerly and only then collect. A lazy
        // `map(..).collect::<Option<_>>()` would stop at the first `None`.
        // The garbler presumably gets `None` from `crt_output`, so it would
        // process a single bundle while the evaluator still expects the rest.
        let mut outputs = Vec::with_capacity(xs.len());
        for x in xs {
            let q = x.composite_modulus();
            let one = b.crt_constant_bundle(1, q).unwrap();
            let y = b.crt_mul(x, &one).unwrap();
            let z = b.crt_relu(&y, "100%", None).unwrap();
            outputs.push(b.crt_output(&z).unwrap());
        }
        outputs.into_iter().collect()
    }

    #[test]
    fn test_relu() {
        let mut rng = rand::thread_rng();
        let n = 10;
        let ps = crate::util::primes_with_width(10);
        let q = crate::util::product(&ps);
        let input = (0..n).map(|_| rng.gen_u128() % q).collect::<Vec<u128>>();

        // Reference result computed with the `Dummy` backend.
        let mut dummy = Dummy::new();
        let dummy_input = input
            .iter()
            .map(|x| dummy.crt_encode(*x, q).unwrap())
            .collect_vec();
        let target = relu(&mut dummy, &dummy_input).unwrap();

        let (sender, receiver) = unix_channel_pair();
        let garbler = std::thread::spawn(move || {
            let mut gb = GB::<AllWire>::new(sender, AesRng::new()).unwrap();
            let xs = gb.crt_encode_many(&input, q).unwrap();
            relu(&mut gb, &xs);
        });

        let mut ev = EV::<AllWire>::new(receiver, AesRng::new()).unwrap();
        let xs = ev.crt_receive_many(n, q).unwrap();
        let result = relu(&mut ev, &xs).unwrap();
        // Surface any garbler-side panic rather than leaving the thread detached.
        garbler.join().unwrap();

        assert_eq!(target, result);
    }

    /// Runs `circ` between a garbler on a spawned thread and an evaluator on
    /// the current thread, then checks the evaluator's output against
    /// `eval_plain` on the same inputs.
    ///
    /// Each party supplies 128 mod-2 inputs, all zero. This presumably matches
    /// the AES circuit used below; other circuits would need different sizes.
    fn test_circuit<CIRC, Wire: WireLabel>(circ: CIRC)
    where
        CIRC: EvaluableCircuit<Dummy>
            + EvaluableCircuit<GB<Wire>>
            + EvaluableCircuit<EV<Wire>>
            + CircuitInfo
            + Send
            + 'static,
    {
        const N_INPUTS: usize = 128;
        // Plain arrays (`Copy`) instead of `vec!`: the spawned closure gets
        // its own copy and the current thread keeps using the originals.
        let zeros = [0_u16; N_INPUTS];
        let moduli = [2_u16; N_INPUTS];

        circ.print_info().unwrap();

        let garbler_circ = circ.clone();
        let (sender, receiver) = unix_channel_pair();
        let garbler = std::thread::spawn(move || {
            let mut gb = GB::<Wire>::new(sender, AesRng::new()).unwrap();
            let xs = gb.encode_many(&zeros, &moduli).unwrap();
            let ys = gb.receive_many(&moduli).unwrap();
            garbler_circ.eval(&mut gb, &xs, &ys).unwrap();
        });

        let mut ev = EV::<Wire>::new(receiver, AesRng::new()).unwrap();
        let xs = ev.receive_many(&moduli).unwrap();
        let ys = ev.encode_many(&zeros, &moduli).unwrap();
        let out = circ.eval(&mut ev, &xs, &ys).unwrap().unwrap();
        garbler.join().unwrap();

        let target = eval_plain(&circ, &zeros, &zeros).unwrap();
        assert_eq!(out, target);
    }

    /// Parses the bundled `AES-non-expanded.txt` circuit.
    fn aes_circuit() -> BinaryCircuit {
        BinaryCircuit::parse(std::io::Cursor::<&'static [u8]>::new(include_bytes!(
            "../../../circuits/AES-non-expanded.txt"
        )))
        .unwrap()
    }

    #[test]
    fn test_aes_arithmetic() {
        test_circuit::<_, AllWire>(aes_circuit());
    }

    #[test]
    fn test_aes_binary() {
        test_circuit::<_, WireMod2>(aes_circuit());
    }
}