fancy_garbling/twopac/semihonest/
mod.rs

1//! Implementation of semi-honest two-party computation.
2
3mod evaluator;
4mod garbler;
5
6pub use evaluator::Evaluator;
7pub use garbler::Garbler;
8
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AllWire, CrtBundle, CrtGadgets, FancyArithmetic, FancyBinary, FancyInput, WireLabel,
        WireMod2,
        circuit::{BinaryCircuit, CircuitInfo, EvaluableCircuit, eval_plain},
        dummy::Dummy,
        util::RngExt,
    };
    use itertools::Itertools;
    use swanky_aes_rng::AesRng;
    use swanky_channel_legacy::{UnixChannel, unix_channel_pair};
    use swanky_ot_chou_orlandi::{Receiver as ChouOrlandiReceiver, Sender as ChouOrlandiSender};

    /// Garbler instantiation shared by all tests: Unix channel, AES RNG, and a
    /// Chou–Orlandi OT sender, generic over the wire-label type `W`.
    type TestGarbler<W> = Garbler<UnixChannel, AesRng, ChouOrlandiSender, W>;
    /// Evaluator instantiation shared by all tests: Unix channel, AES RNG, and a
    /// Chou–Orlandi OT receiver, generic over the wire-label type `W`.
    type TestEvaluator<W> = Evaluator<UnixChannel, AesRng, ChouOrlandiReceiver, W>;

    /// Computes `a + b` with `F::add` and reveals the sum with `F::output`.
    ///
    /// Returns `output`'s result unchanged. In these tests the evaluator's result
    /// is unwrapped to the decoded sum, and the garbler's result is ignored.
    fn addition<F: FancyArithmetic>(
        f: &mut F,
        a: &F::Item,
        b: &F::Item,
    ) -> Result<Option<u16>, F::Error> {
        let c = f.add(a, b)?;
        f.output(&c)
    }

    /// Exhaustively checks 2PC addition mod `Q` over every pair of inputs,
    /// including the pairs whose sum wraps around the modulus.
    #[test]
    fn test_addition_circuit() {
        const Q: u16 = 3;
        // Iterate over the whole residue range; `0..2` would never make `a + b`
        // reach `Q`, leaving the modular wraparound untested.
        for a in 0..Q {
            for b in 0..Q {
                let (sender, receiver) = unix_channel_pair();
                let garbler = std::thread::spawn(move || {
                    let mut gb = TestGarbler::<AllWire>::new(sender, AesRng::new()).unwrap();
                    let x = gb.encode(a, Q).unwrap();
                    let ys = gb.receive_many(&[Q]).unwrap();
                    addition(&mut gb, &x, &ys[0]).unwrap();
                });

                let mut ev = TestEvaluator::<AllWire>::new(receiver, AesRng::new()).unwrap();
                let x = ev.receive(Q).unwrap();
                let ys = ev.encode_many(&[b], &[Q]).unwrap();
                let output = addition(&mut ev, &x, &ys[0]).unwrap().unwrap();
                // Join so that a panic on the garbler side fails this test too.
                garbler.join().unwrap();
                assert_eq!((a + b) % Q, output, "a = {}, b = {}", a, b);
            }
        }
    }

    /// For each CRT bundle, multiplies by the constant bundle `1`, applies
    /// `crt_relu(.., "100%", None)`, and reveals the result with `crt_output`.
    ///
    /// Returns `Some` of all revealed values if every `crt_output` returned
    /// `Some`, and `None` otherwise. Panics if any gadget call fails.
    fn relu<F: FancyArithmetic + FancyBinary>(
        b: &mut F,
        xs: &[CrtBundle<F::Item>],
    ) -> Option<Vec<u128>> {
        let mut outputs = Vec::new();
        for x in xs.iter() {
            let q = x.composite_modulus();
            let c = b.crt_constant_bundle(1, q).unwrap();
            let y = b.crt_mul(x, &c).unwrap();
            let z = b.crt_relu(&y, "100%", None).unwrap();
            outputs.push(b.crt_output(&z).unwrap());
        }
        outputs.into_iter().collect()
    }

    /// Checks that 2PC `relu` on random garbler-supplied CRT inputs matches the
    /// plaintext result computed with `Dummy`.
    #[test]
    fn test_relu() {
        let mut rng = rand::thread_rng();
        let n = 10;
        let ps = crate::util::primes_with_width(10);
        let q = crate::util::product(&ps);
        let input = (0..n).map(|_| rng.gen_u128() % q).collect::<Vec<u128>>();

        // Plaintext reference result.
        let mut dummy = Dummy::new();
        let dummy_input = input
            .iter()
            .map(|x| dummy.crt_encode(*x, q).unwrap())
            .collect_vec();
        let target = relu(&mut dummy, &dummy_input).unwrap();

        // 2PC run: the garbler supplies all inputs; the evaluator learns the outputs.
        let (sender, receiver) = unix_channel_pair();
        let garbler = std::thread::spawn(move || {
            let mut gb = TestGarbler::<AllWire>::new(sender, AesRng::new()).unwrap();
            let xs = gb.crt_encode_many(&input, q).unwrap();
            relu(&mut gb, &xs);
        });

        let mut ev = TestEvaluator::<AllWire>::new(receiver, AesRng::new()).unwrap();
        let xs = ev.crt_receive_many(n, q).unwrap();
        let result = relu(&mut ev, &xs).unwrap();
        // Join so that a panic on the garbler side fails this test too.
        garbler.join().unwrap();
        assert_eq!(target, result);
    }

    /// Draws `n` uniformly random bits, each stored as a `u16` equal to `0` or `1`.
    fn random_bits(n: usize) -> Vec<u16> {
        let mut rng = rand::thread_rng();
        (0..n).map(|_| (rng.gen_u128() & 1) as u16).collect()
    }

    /// Runs `circ` under 2PC and compares the evaluator's output with `eval_plain`.
    ///
    /// The garbler and the evaluator each supply 128 random bits (every input wire
    /// has modulus 2). The same inputs are fed to the plaintext evaluation, and
    /// they are printed if the outputs disagree.
    fn test_circuit<C, Wire: WireLabel>(circ: C)
    where
        C: EvaluableCircuit<Dummy>
            + EvaluableCircuit<TestGarbler<Wire>>
            + EvaluableCircuit<TestEvaluator<Wire>>
            + CircuitInfo
            + Clone
            + Send
            + 'static,
    {
        const NBITS: usize = 128;
        circ.print_info().unwrap();

        // Random rather than all-zero inputs, so the check covers nontrivial values.
        let gb_inputs = random_bits(NBITS);
        let ev_inputs = random_bits(NBITS);
        let moduli = [2_u16; NBITS];

        let gb_circ = circ.clone();
        let gb_bits = gb_inputs.clone();
        let (sender, receiver) = unix_channel_pair();
        let garbler = std::thread::spawn(move || {
            let mut gb = TestGarbler::<Wire>::new(sender, AesRng::new()).unwrap();
            let xs = gb.encode_many(&gb_bits, &moduli).unwrap();
            let ys = gb.receive_many(&moduli).unwrap();
            gb_circ.eval(&mut gb, &xs, &ys).unwrap();
        });

        let mut ev = TestEvaluator::<Wire>::new(receiver, AesRng::new()).unwrap();
        let xs = ev.receive_many(&moduli).unwrap();
        let ys = ev.encode_many(&ev_inputs, &moduli).unwrap();
        let out = circ.eval(&mut ev, &xs, &ys).unwrap().unwrap();
        garbler.join().unwrap();

        let target = eval_plain(&circ, &gb_inputs, &ev_inputs).unwrap();
        assert_eq!(
            out, target,
            "garbler inputs: {:?}, evaluator inputs: {:?}",
            gb_inputs, ev_inputs
        );
    }

    /// Parses the bundled AES circuit (`circuits/AES-non-expanded.txt`).
    fn aes_circuit() -> BinaryCircuit {
        BinaryCircuit::parse(std::io::Cursor::<&'static [u8]>::new(include_bytes!(
            "../../../circuits/AES-non-expanded.txt"
        )))
        .unwrap()
    }

    /// 2PC AES with [`AllWire`] wire labels.
    #[test]
    fn test_aes_arithmetic() {
        test_circuit::<_, AllWire>(aes_circuit());
    }

    /// 2PC AES with [`WireMod2`] wire labels.
    #[test]
    fn test_aes_binary() {
        test_circuit::<_, WireMod2>(aes_circuit());
    }
}