fancy_garbling/garble/
garbler.rs

1use crate::{
2    AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, FancyInput, HasModulus, WireLabel,
3    WireMod2, check_binary,
4    fancy::{BinaryBundle, CrtBundle, Fancy, FancyReveal},
5    garble::binary_and::BinaryWireLabel,
6    hash_wires,
7    util::{RngExt, output_tweak, tweak, tweak2},
8};
9use rand::{CryptoRng, RngCore};
10#[cfg(feature = "serde")]
11use serde::de::DeserializeOwned;
12use std::collections::HashMap;
13use swanky_channel::Channel;
14
15use super::security_warning::warn_proj;
16
/// Garbler half of a garbled-circuit protocol.
///
/// Gates are garbled one at a time; their ciphertexts, and the evaluator's
/// encoded input labels, are written to a [`Channel`] as they are produced.
pub struct Garbler<RNG, Wire> {
    // Randomness used to sample fresh zero labels and deltas.
    rng: RNG,
    // Delta wirelabel for each modulus, keyed by that modulus. The label for
    // value `v` is `zero + delta * v` (see `encode_wire`).
    deltas: HashMap<u16, Wire>,
    // Mod-2 zero label. `new` sends `zero + delta` to the evaluator so that
    // binary negation needs no ciphertexts.
    zero: Wire,
    // Index of the next non-free gate; used to derive per-gate hash tweaks.
    current_gate: usize,
    // Index of the next output wire; used to derive output hash tweaks.
    current_output: usize,
}
27
#[cfg(feature = "serde")]
impl<RNG: CryptoRng + RngCore, Wire: WireLabel + DeserializeOwned> Garbler<RNG, Wire> {
    /// Load pre-chosen deltas from a JSON file.
    ///
    /// The file must deserialize to a map from modulus to delta wirelabel. Each
    /// loaded entry replaces any delta this garbler already holds for that modulus.
    ///
    /// **Caution:** [`Garbler::new`] has already written `zero + delta` (for the
    /// mod-2 delta it sampled) to the evaluator to make binary negation free. If the
    /// file contains a modulus-2 entry, that previously sent label no longer equals
    /// `zero + delta` for the new delta. Negation relies on that equality.
    ///
    /// # Errors
    /// Returns an error in three cases, and leaves the garbler unchanged:
    /// - the file cannot be opened;
    /// - the file is not valid JSON for a `HashMap<u16, Wire>`;
    /// - any loaded wirelabel's modulus differs from the key it is stored under.
    pub fn load_deltas(&mut self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        let f = std::fs::File::open(filename)?;
        let reader = std::io::BufReader::new(f);
        let deltas: HashMap<u16, Wire> = serde_json::from_reader(reader)?;
        // A delta stored under the wrong modulus would silently corrupt every
        // encoding and gate for that modulus, so reject it before mutating state.
        if let Some((q, w)) = deltas.iter().find(|(q, w)| w.modulus() != **q) {
            return Err(format!(
                "delta stored under modulus {q} has modulus {}",
                w.modulus()
            )
            .into());
        }
        self.deltas.extend(deltas);
        Ok(())
    }
}
39
40impl<RNG: CryptoRng + RngCore, Wire: WireLabel> Garbler<RNG, Wire> {
41    /// Create a new [`Garbler`].
42    pub fn new(mut rng: RNG, channel: &mut Channel) -> swanky_error::Result<Self> {
43        let zero = Wire::rand(&mut rng, 2);
44        let delta = Wire::rand_delta(&mut rng, 2);
45        let one = zero.clone() + delta.clone();
46        let mut deltas = HashMap::new();
47        deltas.insert(2, delta);
48        // Send the one wirelabel to the evaluator. This is used to make binary
49        // negation free.
50        channel.write(&one.to_repr())?;
51        Ok(Garbler {
52            zero,
53            deltas,
54            current_gate: 0,
55            current_output: 0,
56            rng,
57        })
58    }
59
60    /// The current non-free gate index of the garbling computation
61    fn current_gate(&mut self) -> usize {
62        let current = self.current_gate;
63        self.current_gate += 1;
64        current
65    }
66
67    /// Create a delta if it has not been created yet for this modulus, otherwise just
68    /// return the existing one.
69    pub fn delta(&mut self, q: u16) -> Wire {
70        if let Some(delta) = self.deltas.get(&q) {
71            return delta.clone();
72        }
73        let w = Wire::rand_delta(&mut self.rng, q);
74        self.deltas.insert(q, w.clone());
75        w
76    }
77
78    /// The current output index of the garbling computation.
79    fn current_output(&mut self) -> usize {
80        let current = self.current_output;
81        self.current_output += 1;
82        current
83    }
84
85    /// Get the deltas, consuming the Garbler.
86    ///
87    /// This is useful for reusing wires in multiple garbled circuit instances.
88    pub fn get_deltas(self) -> HashMap<u16, Wire> {
89        self.deltas
90    }
91
92    /// Send a wire over the established channel.
93    pub fn send_wire(&mut self, wire: &Wire, channel: &mut Channel) -> swanky_error::Result<()> {
94        channel.write(&wire.to_repr())?;
95        Ok(())
96    }
97
98    /// Encode a wire, producing the zero wire as well as the encoded value.
99    pub fn encode_wire(&mut self, val: u16, modulus: u16) -> (Wire, Wire) {
100        let zero = Wire::rand(&mut self.rng, modulus);
101        let delta = self.delta(modulus);
102        let enc = zero.clone() + delta * val;
103        (zero, enc)
104    }
105
106    /// Encode many wires, producing zero wires as well as encoded values.
107    ///
108    /// # Panics
109    /// Panics if the length of `vals` and `moduli` are not equal.
110    pub fn encode_many_wires(&mut self, vals: &[u16], moduli: &[u16]) -> (Vec<Wire>, Vec<Wire>) {
111        assert_eq!(vals.len(), moduli.len());
112
113        let mut gbs = Vec::with_capacity(vals.len());
114        let mut evs = Vec::with_capacity(vals.len());
115        for (x, q) in vals.iter().zip(moduli.iter()) {
116            let (gb, ev) = self.encode_wire(*x, *q);
117            gbs.push(gb);
118            evs.push(ev);
119        }
120        (gbs, evs)
121    }
122
123    /// Encode a `CrtBundle`, producing zero wires as well as encoded values.
124    pub fn crt_encode_wire(
125        &mut self,
126        val: u128,
127        modulus: u128,
128    ) -> (CrtBundle<Wire>, CrtBundle<Wire>) {
129        let ms = crate::util::factor(modulus);
130        let xs = crate::util::crt(val, &ms);
131        let (gbs, evs) = self.encode_many_wires(&xs, &ms);
132        (CrtBundle::new(gbs), CrtBundle::new(evs))
133    }
134
135    /// Encode a `BinaryBundle`, producing zero wires as well as encoded values.
136    pub fn bin_encode_wire(
137        &mut self,
138        val: u128,
139        nbits: usize,
140    ) -> (BinaryBundle<Wire>, BinaryBundle<Wire>) {
141        let xs = crate::util::u128_to_bits(val, nbits);
142        let ms = vec![2; nbits];
143        let (gbs, evs) = self.encode_many_wires(&xs, &ms);
144        (BinaryBundle::new(gbs), BinaryBundle::new(evs))
145    }
146}
147
148impl<RNG: CryptoRng + RngCore, Wire: WireLabel> FancyInput for Garbler<RNG, Wire> {
149    type Item = Wire;
150
151    fn encode_many(
152        &mut self,
153        values: &[u16],
154        moduli: &[u16],
155        channel: &mut Channel,
156    ) -> swanky_error::Result<Vec<Self::Item>> {
157        let (zero, encoded) = self.encode_many_wires(values, moduli);
158        for wire in encoded {
159            channel.write(&wire.to_repr())?;
160        }
161        Ok(zero)
162    }
163
164    fn receive_many(
165        &mut self,
166        _moduli: &[u16],
167        _: &mut Channel,
168    ) -> swanky_error::Result<Vec<Self::Item>> {
169        unimplemented!("Garbler cannot receive values")
170    }
171}
172
173impl<RNG: RngCore + CryptoRng, Wire: WireLabel> FancyReveal for Garbler<RNG, Wire> {
174    fn reveal(&mut self, x: &Wire, channel: &mut Channel) -> swanky_error::Result<u16> {
175        // The evaluator needs our cooperation in order to see the output.
176        // Hence, we call output() ourselves.
177        self.output(x, channel)?;
178        let val = channel.read::<u16>()?;
179        Ok(val)
180    }
181}
182
183impl<RNG: RngCore + CryptoRng, W: BinaryWireLabel> FancyBinary for Garbler<RNG, W> {
184    fn and(
185        &mut self,
186        A: &Self::Item,
187        B: &Self::Item,
188        channel: &mut Channel,
189    ) -> swanky_error::Result<Self::Item> {
190        let delta = self.delta(2);
191        let gate_num = self.current_gate();
192        let (gate0, gate1, C) = W::garble_and_gate(gate_num, A, B, &delta);
193        channel.write(&gate0)?;
194        channel.write(&gate1)?;
195        Ok(C)
196    }
197
198    fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
199        *x + *y
200    }
201
202    /// We can negate by having garbler xor wire with Delta
203    ///
204    /// Since we treat all garbler wires as zero,
205    /// xoring with delta conceptually negates the value of the wire
206    fn negate(&mut self, x: &Self::Item) -> Self::Item {
207        let zero = self.zero;
208        self.xor(&zero, x)
209    }
210}
211
212impl<RNG: RngCore + CryptoRng> FancyBinary for Garbler<RNG, AllWire> {
213    /// We can negate by having garbler xor wire with Delta
214    ///
215    /// Since we treat all garbler wires as zero,
216    /// xoring with delta conceptually negates the value of the wire
217    fn negate(&mut self, x: &Self::Item) -> Self::Item {
218        check_binary!(x);
219
220        let zero = self.zero.clone();
221        self.xor(&zero, x)
222    }
223
224    /// Xor is just addition
225    fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
226        check_binary!(x);
227        check_binary!(y);
228
229        self.add(x, y)
230    }
231
232    /// Use binary and_gate
233    fn and(
234        &mut self,
235        x: &Self::Item,
236        y: &Self::Item,
237        channel: &mut Channel,
238    ) -> swanky_error::Result<Self::Item> {
239        if let (AllWire::Mod2(A), AllWire::Mod2(B), AllWire::Mod2(ref delta)) =
240            (x, y, self.delta(2))
241        {
242            let gate_num = self.current_gate();
243            let (gate0, gate1, C) = WireMod2::garble_and_gate(gate_num, A, B, delta);
244            channel.write(&gate0)?;
245            channel.write(&gate1)?;
246            return Ok(AllWire::Mod2(C));
247        }
248        // If we got here, one of the wires isn't binary
249        check_binary!(x);
250        check_binary!(y);
251
252        // Shouldn't be reachable, unless the wire has modulus 2 but is not AllWire::Mod2()
253        unreachable!()
254    }
255}
256
257impl<RNG: RngCore + CryptoRng, Wire: WireLabel + ArithmeticWire> FancyArithmetic
258    for Garbler<RNG, Wire>
259{
260    fn add(&mut self, x: &Wire, y: &Wire) -> Wire {
261        assert_eq!(x.modulus(), y.modulus());
262        x.clone() + y.clone()
263    }
264
265    fn sub(&mut self, x: &Wire, y: &Wire) -> Wire {
266        assert_eq!(x.modulus(), y.modulus());
267        x.clone() - y.clone()
268    }
269
270    fn cmul(&mut self, x: &Wire, c: u16) -> Wire {
271        x.clone() * c
272    }
273
274    fn mul(&mut self, A: &Wire, B: &Wire, channel: &mut Channel) -> swanky_error::Result<Wire> {
275        if A.modulus() < B.modulus() {
276            return self.mul(B, A, channel);
277        }
278
279        let q = A.modulus();
280        let qb = B.modulus();
281        let gate_num = self.current_gate();
282
283        let D = self.delta(q);
284        let Db = self.delta(qb);
285
286        let r;
287        let mut gate = vec![Default::default(); q as usize + qb as usize - 2];
288
289        // hack for unequal moduli
290        if q != qb {
291            // would need to pack minitable into more than one u128 to support qb > 8
292            assert!(
293                qb <= 8,
294                "`B.modulus()` with asymmetric moduli is capped at 8"
295            );
296
297            r = self.rng.gen_u16() % q;
298            let t = tweak2(gate_num as u64, 1);
299
300            let mut minitable = vec![u128::default(); qb as usize];
301            let mut B_ = B.clone();
302            for b in 0..qb {
303                if b > 0 {
304                    B_ += Db.clone();
305                }
306                let new_color = ((r + b) % q) as u128;
307                let ct = (u128::from(B_.hash(t)) & 0xFFFF) ^ new_color;
308                minitable[B_.color() as usize] = ct;
309            }
310
311            let mut packed = 0;
312            for (i, item) in minitable.iter().enumerate().take(qb as usize) {
313                packed += item << (16 * i);
314            }
315            gate.push(packed.into());
316        } else {
317            r = B.color(); // secret value known only to the garbler (ev knows r+b)
318        }
319
320        let g = tweak2(gate_num as u64, 0);
321
322        // X = H(A+aD) + arD such that a + A.color == 0
323        let alpha = (q - A.color()) % q; // alpha = -A.color
324        let X1 = A.clone() + D.clone() * alpha;
325
326        // Y = H(B + bD) + (b + r)A such that b + B.color == 0
327        let beta = (qb - B.color()) % qb;
328        let Y1 = B.clone() + Db.clone() * beta;
329
330        let [hashX, hashY] = hash_wires([&X1, &Y1], g);
331
332        let X = Wire::hash_to_mod(hashX, q) + D.clone() * (alpha * r % q);
333        let Y = Wire::hash_to_mod(hashY, q) + A.clone() * ((beta + r) % q);
334
335        let mut precomp = Vec::with_capacity(q as usize);
336        // precompute a lookup table of X.minus(&D_cmul[(a * r % q)])
337        //                            = X.plus(&D_cmul[((q - (a * r % q)) % q)])
338        let mut X_ = X.clone();
339        precomp.push(X_.to_repr());
340        for _ in 1..q {
341            X_ += D.clone();
342            precomp.push(X_.to_repr());
343        }
344
345        // We can vectorize the hashes here too, but then we need to precompute all `q` sums of A
346        // with delta [A, A + D, A + D + D, etc.]
347        // Would probably need another alloc which isn't great
348        let mut A_ = A.clone();
349        for a in 0..q {
350            if a > 0 {
351                A_ += D.clone();
352            }
353            // garbler's half-gate: outputs X-arD
354            // G = H(A+aD) ^ X+a(-r)D = H(A+aD) ^ X-arD
355            if A_.color() != 0 {
356                gate[A_.color() as usize - 1] =
357                    A_.hash(g) ^ precomp[((q - (a * r % q)) % q) as usize];
358            }
359        }
360        precomp.clear();
361
362        // precompute a lookup table of Y.minus(&A_cmul[((b+r) % q)])
363        //                            = Y.plus(&A_cmul[((q - ((b+r) % q)) % q)])
364        let mut Y_ = Y.clone();
365        precomp.push(Y_.to_repr());
366        for _ in 1..q {
367            Y_ += A.clone();
368            precomp.push(Y_.to_repr());
369        }
370
371        // Same note about vectorization as A
372        let mut B_ = B.clone();
373        for b in 0..qb {
374            if b > 0 {
375                B_ += Db.clone();
376            }
377            // evaluator's half-gate: outputs Y-(b+r)D
378            // G = H(B+bD) + Y-(b+r)A
379            if B_.color() != 0 {
380                gate[q as usize - 1 + B_.color() as usize - 1] =
381                    B_.hash(g) ^ precomp[((q - ((b + r) % q)) % q) as usize];
382            }
383        }
384
385        for block in gate.iter() {
386            channel.write(block)?;
387        }
388        Ok(X + Y)
389    }
390
391    fn proj(
392        &mut self,
393        A: &Wire,
394        q_out: u16,
395        tt: Option<Vec<u16>>,
396        channel: &mut Channel,
397    ) -> swanky_error::Result<Wire> {
398        warn_proj();
399        assert!(tt.is_some(), "`tt` must not be `None`");
400        let tt = tt.unwrap();
401
402        let q_in = A.modulus();
403        let mut gate = vec![Default::default(); q_in as usize - 1];
404
405        let tao = A.color();
406        let g = tweak(self.current_gate());
407
408        let Din = self.delta(q_in);
409        let Dout = self.delta(q_out);
410
411        // output zero-wire
412        // W_g^0 <- -H(g, W_{a_1}^0 - \tao\Delta_m) - \phi(-\tao)\Delta_n
413        let C = (A.clone() + Din.clone() * ((q_in - tao) % q_in)).hashback(g, q_out)
414            + Dout.clone() * ((q_out - tt[((q_in - tao) % q_in) as usize]) % q_out);
415
416        // precompute `let C_ = C.plus(&Dout.cmul(tt[x as usize]))`
417        let C_precomputed = {
418            let mut C_ = C.clone();
419            (0..q_out)
420                .map(|x| {
421                    if x > 0 {
422                        C_ += Dout.clone();
423                    }
424                    C_.to_repr()
425                })
426                .collect::<Vec<_>>()
427        };
428
429        let mut A_ = A.clone();
430        for x in 0..q_in {
431            if x > 0 {
432                A_ += Din.clone(); // avoiding expensive cmul for `A_ = A.plus(&Din.cmul(x))`
433            }
434
435            let ix = (tao as usize + x as usize) % q_in as usize;
436            if ix == 0 {
437                continue;
438            }
439
440            let ct = A_.hash(g) ^ C_precomputed[tt[x as usize] as usize];
441            gate[ix - 1] = ct;
442        }
443
444        for block in gate.iter() {
445            channel.write(block)?;
446        }
447        Ok(C)
448    }
449}
450
451impl<RNG: RngCore + CryptoRng, Wire: WireLabel> Fancy for Garbler<RNG, Wire> {
452    type Item = Wire;
453
454    fn constant(&mut self, x: u16, q: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
455        let zero = Wire::rand(&mut self.rng, q);
456        let wire = zero.clone() + self.delta(q) * x;
457        self.send_wire(&wire, channel)?;
458        Ok(zero)
459    }
460
461    fn output(&mut self, X: &Wire, channel: &mut Channel) -> swanky_error::Result<Option<u16>> {
462        let q = X.modulus();
463        let i = self.current_output();
464        let D = self.delta(q);
465        for k in 0..q {
466            let block = (X.clone() + D.clone() * k).hash(output_tweak(i, k));
467            channel.write(&block)?;
468        }
469        Ok(None)
470    }
471}