fancy_garbling/fancy/binary.rs

1use crate::{
2    FancyBinary,
3    errors::FancyError,
4    fancy::{
5        HasModulus,
6        bundle::{Bundle, BundleGadgets},
7    },
8    util,
9};
10use itertools::Itertools;
11use std::ops::{Deref, DerefMut};
12
13/// Bundle which is explicitly binary representation.
14#[derive(Clone)]
15pub struct BinaryBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> BinaryBundle<W> {
18    /// Create a new binary bundle from a vector of wires.
19    pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
20        BinaryBundle(Bundle::new(ws))
21    }
22
23    /// Extract the underlying bundle from this binary bundle.
24    pub fn extract(self) -> Bundle<W> {
25        self.0
26    }
27}
28
29impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
30    type Target = Bundle<W>;
31
32    fn deref(&self) -> &Bundle<W> {
33        &self.0
34    }
35}
36
37impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
38    fn deref_mut(&mut self) -> &mut Bundle<W> {
39        &mut self.0
40    }
41}
42
43impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
44    fn from(b: Bundle<W>) -> BinaryBundle<W> {
45        debug_assert!(b.moduli().iter().all(|&p| p == 2));
46        BinaryBundle(b)
47    }
48}
49
50impl<F: FancyBinary> BinaryGadgets for F {}
51
52/// Extension trait for `Fancy` providing gadgets that operate over bundles of mod2 wires.
53pub trait BinaryGadgets: FancyBinary + BundleGadgets {
54    /// Create a constant bundle using base 2 inputs.
55    fn bin_constant_bundle(
56        &mut self,
57        val: u128,
58        nbits: usize,
59    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
60        self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits])
61            .map(BinaryBundle)
62    }
63
64    /// Output a binary bundle and interpret the result as a `u128`.
65    fn bin_output(&mut self, x: &BinaryBundle<Self::Item>) -> Result<Option<u128>, Self::Error> {
66        Ok(self.output_bundle(x)?.map(|bs| util::u128_from_bits(&bs)))
67    }
68
69    /// Output a slice of binary bundles and interpret the results as a `u128`.
70    fn bin_outputs(
71        &mut self,
72        xs: &[BinaryBundle<Self::Item>],
73    ) -> Result<Option<Vec<u128>>, Self::Error> {
74        let mut zs = Vec::with_capacity(xs.len());
75        for x in xs.iter() {
76            let z = self.bin_output(x)?;
77            zs.push(z);
78        }
79        Ok(zs.into_iter().collect())
80    }
81
82    /// Xor the bits of two bundles together pairwise.
83    fn bin_xor(
84        &mut self,
85        x: &BinaryBundle<Self::Item>,
86        y: &BinaryBundle<Self::Item>,
87    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
88        x.wires()
89            .iter()
90            .zip(y.wires().iter())
91            .map(|(x, y)| self.xor(x, y))
92            .collect::<Result<Vec<Self::Item>, Self::Error>>()
93            .map(BinaryBundle::new)
94    }
95
96    /// And the bits of two bundles together pairwise.
97    fn bin_and(
98        &mut self,
99        x: &BinaryBundle<Self::Item>,
100        y: &BinaryBundle<Self::Item>,
101    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
102        x.wires()
103            .iter()
104            .zip(y.wires().iter())
105            .map(|(x, y)| self.and(x, y))
106            .collect::<Result<Vec<Self::Item>, Self::Error>>()
107            .map(BinaryBundle::new)
108    }
109
110    /// Or the bits of two bundles together pairwise.
111    fn bin_or(
112        &mut self,
113        x: &BinaryBundle<Self::Item>,
114        y: &BinaryBundle<Self::Item>,
115    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
116        x.wires()
117            .iter()
118            .zip(y.wires().iter())
119            .map(|(x, y)| self.or(x, y))
120            .collect::<Result<Vec<Self::Item>, Self::Error>>()
121            .map(BinaryBundle::new)
122    }
123
124    /// Binary addition. Returns the result and the carry.
125    fn bin_addition(
126        &mut self,
127        xs: &BinaryBundle<Self::Item>,
128        ys: &BinaryBundle<Self::Item>,
129    ) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
130        if xs.moduli() != ys.moduli() {
131            return Err(Self::Error::from(FancyError::UnequalModuli));
132        }
133        let xwires = xs.wires();
134        let ywires = ys.wires();
135        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None)?;
136        let mut bs = vec![z];
137        for i in 1..xwires.len() {
138            let res = self.adder(&xwires[i], &ywires[i], Some(&c))?;
139            z = res.0;
140            c = res.1;
141            bs.push(z);
142        }
143        Ok((BinaryBundle::new(bs), c))
144    }
145
146    /// Binary addition. Avoids creating extra gates for the final carry.
147    fn bin_addition_no_carry(
148        &mut self,
149        xs: &BinaryBundle<Self::Item>,
150        ys: &BinaryBundle<Self::Item>,
151    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
152        if xs.moduli() != ys.moduli() {
153            return Err(Self::Error::from(FancyError::UnequalModuli));
154        }
155        let xwires = xs.wires();
156        let ywires = ys.wires();
157        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None)?;
158        let mut bs = vec![z];
159        for i in 1..xwires.len() - 1 {
160            let res = self.adder(&xwires[i], &ywires[i], Some(&c))?;
161            z = res.0;
162            c = res.1;
163            bs.push(z);
164        }
165        // xor instead of add
166        z = self.xor_many(&[
167            xwires.last().unwrap().clone(),
168            ywires.last().unwrap().clone(),
169            c,
170        ])?;
171        bs.push(z);
172        Ok(BinaryBundle::new(bs))
173    }
174
175    /// Binary multiplication.
176    ///
177    /// Returns the lower-order half of the output bits, ie a number with the same number
178    /// of bits as the inputs.
179    fn bin_multiplication_lower_half(
180        &mut self,
181        xs: &BinaryBundle<Self::Item>,
182        ys: &BinaryBundle<Self::Item>,
183    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
184        if xs.moduli() != ys.moduli() {
185            return Err(Self::Error::from(FancyError::UnequalModuli));
186        }
187
188        let xwires = xs.wires();
189        let ywires = ys.wires();
190
191        let mut sum = xwires
192            .iter()
193            .map(|x| self.and(x, &ywires[0]))
194            .collect::<Result<Vec<Self::Item>, Self::Error>>()
195            .map(BinaryBundle::new)?;
196
197        for i in 1..xwires.len() {
198            let mul = xwires
199                .iter()
200                .map(|x| self.and(x, &ywires[i]))
201                .collect::<Result<Vec<Self::Item>, Self::Error>>()
202                .map(BinaryBundle::new)?;
203            let shifted = self.shift(&mul, i).map(BinaryBundle)?;
204            sum = self.bin_addition_no_carry(&sum, &shifted)?;
205        }
206
207        Ok(sum)
208    }
209
210    /// Full multiplier
211    fn bin_mul(
212        &mut self,
213        xs: &BinaryBundle<Self::Item>,
214        ys: &BinaryBundle<Self::Item>,
215    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
216        if xs.moduli() != ys.moduli() {
217            return Err(Self::Error::from(FancyError::UnequalModuli));
218        }
219
220        let xwires = xs.wires();
221        let ywires = ys.wires();
222
223        let mut sum = xwires
224            .iter()
225            .map(|x| self.and(x, &ywires[0]))
226            .collect::<Result<_, _>>()
227            .map(BinaryBundle::new)?;
228
229        let zero = self.constant(0, 2)?;
230        sum.pad(&zero, 1);
231
232        for i in 1..xwires.len() {
233            let mul = xwires
234                .iter()
235                .map(|x| self.and(x, &ywires[i]))
236                .collect::<Result<_, _>>()
237                .map(BinaryBundle::new)?;
238            let shifted = self.shift_extend(&mul, i).map(BinaryBundle::from)?;
239            let res = self.bin_addition(&sum, &shifted)?;
240            sum = res.0;
241            sum.push(res.1);
242        }
243
244        Ok(sum)
245    }
246
247    /// Divider
248    fn bin_div(
249        &mut self,
250        xs: &BinaryBundle<Self::Item>,
251        ys: &BinaryBundle<Self::Item>,
252    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
253        if xs.moduli() != ys.moduli() {
254            return Err(Self::Error::from(FancyError::UnequalModuli));
255        }
256        let ys_neg = self.bin_twos_complement(ys)?;
257        let mut acc = self.bin_constant_bundle(0, xs.size())?;
258        let mut qs = BinaryBundle::new(Vec::new());
259        for x in xs.iter().rev() {
260            acc.pop();
261            acc.insert(0, x.clone());
262            let (res, cout) = self.bin_addition(&acc, &ys_neg)?;
263            acc = self.bin_multiplex(&cout, &acc, &res)?;
264            qs.push(cout);
265        }
266        qs.reverse(); // Switch back to little-endian
267        Ok(qs)
268    }
269
270    /// Compute the twos complement of the input bundle (which must be base 2).
271    fn bin_twos_complement(
272        &mut self,
273        xs: &BinaryBundle<Self::Item>,
274    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
275        let not_xs = xs
276            .wires()
277            .iter()
278            .map(|x| self.negate(x))
279            .collect::<Result<Vec<Self::Item>, Self::Error>>()
280            .map(BinaryBundle::new)?;
281        let one = self.bin_constant_bundle(1, xs.size())?;
282        self.bin_addition_no_carry(&not_xs, &one)
283    }
284
285    /// Subtract two binary bundles. Returns the result and whether it underflowed.
286    ///
287    /// Due to the way that `twos_complement(0) = 0`, underflow indicates `y != 0 && x >= y`.
288    fn bin_subtraction(
289        &mut self,
290        xs: &BinaryBundle<Self::Item>,
291        ys: &BinaryBundle<Self::Item>,
292    ) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
293        let neg_ys = self.bin_twos_complement(ys)?;
294        self.bin_addition(xs, &neg_ys)
295    }
296
297    /// If `x=0` return `c1` as a bundle of constant bits, else return `c2`.
298    fn bin_multiplex_constant_bits(
299        &mut self,
300        x: &Self::Item,
301        c1: u128,
302        c2: u128,
303        nbits: usize,
304    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
305        let c1_bs = util::u128_to_bits(c1, nbits)
306            .into_iter()
307            .map(|x: u16| x > 0)
308            .collect_vec();
309        let c2_bs = util::u128_to_bits(c2, nbits)
310            .into_iter()
311            .map(|x: u16| x > 0)
312            .collect_vec();
313        c1_bs
314            .into_iter()
315            .zip(c2_bs.into_iter())
316            .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2))
317            .collect::<Result<Vec<Self::Item>, Self::Error>>()
318            .map(BinaryBundle::new)
319    }
320
321    /// Multiplex gadget for binary bundles
322    fn bin_multiplex(
323        &mut self,
324        b: &Self::Item,
325        x: &BinaryBundle<Self::Item>,
326        y: &BinaryBundle<Self::Item>,
327    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
328        x.wires()
329            .iter()
330            .zip(y.wires().iter())
331            .map(|(xwire, ywire)| self.mux(b, xwire, ywire))
332            .collect::<Result<Vec<Self::Item>, Self::Error>>()
333            .map(BinaryBundle::new)
334    }
335
336    /// Write the constant in binary and that gives you the shift amounts, Eg.. 7x is 4x+2x+x.
337    fn bin_cmul(
338        &mut self,
339        x: &BinaryBundle<Self::Item>,
340        c: u128,
341        nbits: usize,
342    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
343        let zero = self.bin_constant_bundle(0, nbits)?;
344        util::u128_to_bits(c, nbits)
345            .into_iter()
346            .enumerate()
347            .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
348            .fold(Ok(zero), |z, shift_amt| {
349                let s = self.shift(x, shift_amt).map(BinaryBundle)?;
350                self.bin_addition_no_carry(&(z?), &s)
351            })
352    }
353
354    /// Compute the absolute value of a binary bundle.
355    fn bin_abs(
356        &mut self,
357        x: &BinaryBundle<Self::Item>,
358    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
359        let sign = x.wires().last().unwrap();
360        let negated = self.bin_twos_complement(x)?;
361        self.bin_multiplex(sign, x, &negated)
362    }
363
364    /// Returns 1 if `x < y` (signed version)
365    fn bin_lt_signed(
366        &mut self,
367        x: &BinaryBundle<Self::Item>,
368        y: &BinaryBundle<Self::Item>,
369    ) -> Result<Self::Item, Self::Error> {
370        // determine whether x and y are positive or negative
371        let x_neg = &x.wires().last().unwrap();
372        let y_neg = &y.wires().last().unwrap();
373        let x_pos = self.negate(x_neg)?;
374        let y_pos = self.negate(y_neg)?;
375
376        // broken into cases based on x and y being negative or positive
377        // base case: if x and y have the same sign - use unsigned lt
378        let x_lt_y_unsigned = self.bin_lt(x, y)?;
379
380        // if x is negative and y is positive then x < y
381        let tru = self.constant(1, 2)?;
382        let x_neg_y_pos = self.and(x_neg, &y_pos)?;
383        let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru)?;
384
385        // if x is positive and y is negative then !(x < y)
386        let fls = self.constant(0, 2)?;
387        let x_pos_y_neg = self.and(&x_pos, y_neg)?;
388        self.mux(&x_pos_y_neg, &r2, &fls)
389    }
390
391    /// Returns 1 if `x < y`.
392    fn bin_lt(
393        &mut self,
394        x: &BinaryBundle<Self::Item>,
395        y: &BinaryBundle<Self::Item>,
396    ) -> Result<Self::Item, Self::Error> {
397        // underflow indicates y != 0 && x >= y
398        // requiring special care to remove the y != 0, which is what follows.
399        let (_, lhs) = self.bin_subtraction(x, y)?;
400
401        // Now we build a clause equal to (y == 0 || x >= y), which we can OR with
402        // lhs to remove the y==0 aspect.
403        // check if y==0
404        let y_contains_1 = self.or_many(y.wires())?;
405        let y_eq_0 = self.negate(&y_contains_1)?;
406
407        // if x != 0, then x >= y, ... assuming x is not negative
408        let x_contains_1 = self.or_many(x.wires())?;
409
410        // y == 0 && x >= y
411        let rhs = self.and(&y_eq_0, &x_contains_1)?;
412
413        // (y != 0 && x >= y) || (y == 0 && x >= y)
414        // => x >= y && (y != 0 || y == 0)\
415        // => x >= y && 1
416        // => x >= y
417        let geq = self.or(&lhs, &rhs)?;
418        let ngeq = self.negate(&geq)?;
419
420        let xy_neq_0 = self.or(&y_contains_1, &x_contains_1)?;
421        self.and(&xy_neq_0, &ngeq)
422    }
423
424    /// Returns 1 if `x >= y`.
425    fn bin_geq(
426        &mut self,
427        x: &BinaryBundle<Self::Item>,
428        y: &BinaryBundle<Self::Item>,
429    ) -> Result<Self::Item, Self::Error> {
430        let z = self.bin_lt(x, y)?;
431        self.negate(&z)
432    }
433
434    /// Compute the maximum bundle in `xs`.
435    fn bin_max(
436        &mut self,
437        xs: &[BinaryBundle<Self::Item>],
438    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
439        if xs.is_empty() {
440            return Err(Self::Error::from(FancyError::InvalidArgNum {
441                got: xs.len(),
442                needed: 1,
443            }));
444        }
445        xs.iter().skip(1).fold(Ok(xs[0].clone()), |x, y| {
446            x.map(|x| {
447                let pos = self.bin_lt(&x, y)?;
448                let neg = self.negate(&pos)?;
449                x.wires()
450                    .iter()
451                    .zip(y.wires().iter())
452                    .map(|(x, y)| {
453                        let xp = self.and(x, &neg)?;
454                        let yp = self.and(y, &pos)?;
455                        self.xor(&xp, &yp)
456                    })
457                    .collect::<Result<Vec<Self::Item>, Self::Error>>()
458                    .map(BinaryBundle::new)
459            })?
460        })
461    }
462
463    /// Demux a binary bundle into a unary vector.
464    fn bin_demux(&mut self, x: &BinaryBundle<Self::Item>) -> Result<Vec<Self::Item>, Self::Error> {
465        let wires = x.wires();
466        let nbits = wires.len();
467        if nbits > 8 {
468            return Err(Self::Error::from(FancyError::InvalidArg(
469                "wire bitlength too large".to_string(),
470            )));
471        }
472
473        let mut outs = Vec::with_capacity(1 << nbits);
474
475        for ix in 0..1 << nbits {
476            let mut acc = wires[0].clone();
477            if (ix & 1) == 0 {
478                acc = self.negate(&acc)?;
479            }
480            for (i, w) in wires.iter().enumerate().skip(1) {
481                if ((ix >> i) & 1) > 0 {
482                    acc = self.and(&acc, w)?;
483                } else {
484                    let not_w = self.negate(w)?;
485                    acc = self.and(&acc, &not_w)?;
486                }
487            }
488            outs.push(acc);
489        }
490
491        Ok(outs)
492    }
493
494    /// arithmetic right shift (shifts the sign of the MSB into the new spaces)
495    fn bin_rsa(
496        &mut self,
497        x: &BinaryBundle<Self::Item>,
498        c: usize,
499    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
500        self.bin_shr(x, c, x.wires().last().unwrap())
501    }
502
503    /// logical right shift (shifts 0 into the empty spaces)
504    fn bin_rsl(
505        &mut self,
506        x: &BinaryBundle<Self::Item>,
507        c: usize,
508    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
509        let zero = self.constant(0, 2)?;
510        self.bin_shr(x, c, &zero)
511    }
512
513    /// shift a value right by a constant, filling space on the right by `pad`
514    fn bin_shr(
515        &mut self,
516        x: &BinaryBundle<Self::Item>,
517        c: usize,
518        pad: &Self::Item,
519    ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
520        let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
521
522        for i in 0..x.wires().len() {
523            let src_idx = i + c;
524            if src_idx >= x.wires().len() {
525                wires.push(pad.clone())
526            } else {
527                wires.push(x.wires()[src_idx].clone())
528            }
529        }
530
531        Ok(BinaryBundle::new(wires))
532    }
533    /// Compute `x == y` for binary bundles.
534    fn bin_eq_bundles(
535        &mut self,
536        x: &BinaryBundle<Self::Item>,
537        y: &BinaryBundle<Self::Item>,
538    ) -> Result<Self::Item, Self::Error> {
539        // compute (x^y == 0) for each residue
540        let zs = x
541            .wires()
542            .iter()
543            .zip_eq(y.wires().iter())
544            .map(|(x, y)| {
545                let xy = self.xor(x, y)?;
546                self.negate(&xy)
547            })
548            .collect::<Result<Vec<Self::Item>, Self::Error>>()?;
549        // and_many will return 1 only if all outputs of xnor are 1
550        // indicating equality
551        self.and_many(&zs)
552    }
553}