// fancy_garbling/fancy/binary.rs

1use crate::{
2    FancyBinary,
3    fancy::{
4        HasModulus,
5        bundle::{Bundle, BundleGadgets},
6    },
7    util,
8};
9use itertools::Itertools;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12use std::ops::{Deref, DerefMut};
13use swanky_channel::Channel;
14
15/// Bundle which is explicitly binary representation.
16#[derive(Clone)]
17#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
18pub struct BinaryBundle<W>(Bundle<W>);
19
20impl<W: Clone + HasModulus> BinaryBundle<W> {
21    /// Create a new binary bundle from a vector of wires.
22    pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
23        BinaryBundle(Bundle::new(ws))
24    }
25
26    /// Extract the underlying bundle from this binary bundle.
27    pub fn extract(self) -> Bundle<W> {
28        self.0
29    }
30}
31
32impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
33    type Target = Bundle<W>;
34
35    fn deref(&self) -> &Bundle<W> {
36        &self.0
37    }
38}
39
40impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
41    fn deref_mut(&mut self) -> &mut Bundle<W> {
42        &mut self.0
43    }
44}
45
46impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
47    fn from(b: Bundle<W>) -> BinaryBundle<W> {
48        debug_assert!(b.moduli().iter().all(|&p| p == 2));
49        BinaryBundle(b)
50    }
51}
52
53impl<F: FancyBinary> BinaryGadgets for F {}
54
55/// Extension trait for `Fancy` providing gadgets that operate over bundles of mod2 wires.
56pub trait BinaryGadgets: FancyBinary + BundleGadgets {
57    /// Create a constant bundle using base 2 inputs.
58    fn bin_constant_bundle(
59        &mut self,
60        val: u128,
61        nbits: usize,
62        channel: &mut Channel,
63    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
64        self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits], channel)
65            .map(BinaryBundle)
66    }
67
68    /// Output a binary bundle and interpret the result as a `u128`.
69    fn bin_output(
70        &mut self,
71        x: &BinaryBundle<Self::Item>,
72        channel: &mut Channel,
73    ) -> swanky_error::Result<Option<u128>> {
74        Ok(self
75            .output_bundle(x, channel)?
76            .map(|bs| util::u128_from_bits(&bs)))
77    }
78
79    /// Output a slice of binary bundles and interpret the results as a `u128`.
80    fn bin_outputs(
81        &mut self,
82        xs: &[BinaryBundle<Self::Item>],
83        channel: &mut Channel,
84    ) -> swanky_error::Result<Option<Vec<u128>>> {
85        let mut zs = Vec::with_capacity(xs.len());
86        for x in xs.iter() {
87            let z = self.bin_output(x, channel)?;
88            zs.push(z);
89        }
90        Ok(zs.into_iter().collect())
91    }
92
93    /// Xor the bits of two bundles together pairwise.
94    fn bin_xor(
95        &mut self,
96        x: &BinaryBundle<Self::Item>,
97        y: &BinaryBundle<Self::Item>,
98    ) -> BinaryBundle<Self::Item> {
99        BinaryBundle::new(
100            x.wires()
101                .iter()
102                .zip(y.wires().iter())
103                .map(|(x, y)| self.xor(x, y))
104                .collect::<Vec<Self::Item>>(),
105        )
106    }
107
108    /// And the bits of two bundles together pairwise.
109    fn bin_and(
110        &mut self,
111        x: &BinaryBundle<Self::Item>,
112        y: &BinaryBundle<Self::Item>,
113        channel: &mut Channel,
114    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
115        x.wires()
116            .iter()
117            .zip(y.wires().iter())
118            .map(|(x, y)| self.and(x, y, channel))
119            .collect::<swanky_error::Result<Vec<Self::Item>>>()
120            .map(BinaryBundle::new)
121    }
122
123    /// Or the bits of two bundles together pairwise.
124    fn bin_or(
125        &mut self,
126        x: &BinaryBundle<Self::Item>,
127        y: &BinaryBundle<Self::Item>,
128        channel: &mut Channel,
129    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
130        x.wires()
131            .iter()
132            .zip(y.wires().iter())
133            .map(|(x, y)| self.or(x, y, channel))
134            .collect::<swanky_error::Result<Vec<Self::Item>>>()
135            .map(BinaryBundle::new)
136    }
137
138    /// Binary addition. Returns the result and the carry.
139    ///
140    /// # Panics
141    /// This panics if `xs` and `ys` do not have equal moduli.
142    fn bin_addition(
143        &mut self,
144        xs: &BinaryBundle<Self::Item>,
145        ys: &BinaryBundle<Self::Item>,
146        channel: &mut Channel,
147    ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
148        assert_eq!(xs.moduli(), ys.moduli());
149        let xwires = xs.wires();
150        let ywires = ys.wires();
151        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
152        let mut bs = vec![z];
153        for i in 1..xwires.len() {
154            let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
155            z = res.0;
156            c = res.1;
157            bs.push(z);
158        }
159        Ok((BinaryBundle::new(bs), c))
160    }
161
162    /// Binary addition. Avoids creating extra gates for the final carry.
163    ///
164    /// # Panics
165    /// This panics if `xs` and `ys` do not have equal moduli.
166    fn bin_addition_no_carry(
167        &mut self,
168        xs: &BinaryBundle<Self::Item>,
169        ys: &BinaryBundle<Self::Item>,
170        channel: &mut Channel,
171    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
172        assert_eq!(xs.moduli(), ys.moduli());
173        let xwires = xs.wires();
174        let ywires = ys.wires();
175        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
176        let mut bs = vec![z];
177        for i in 1..xwires.len() - 1 {
178            let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
179            z = res.0;
180            c = res.1;
181            bs.push(z);
182        }
183        // xor instead of add
184        z = self.xor_many(&[
185            xwires.last().unwrap().clone(),
186            ywires.last().unwrap().clone(),
187            c,
188        ]);
189        bs.push(z);
190        Ok(BinaryBundle::new(bs))
191    }
192
193    /// Binary multiplication.
194    ///
195    /// Returns the lower-order half of the output bits, ie a number with the same number
196    /// of bits as the inputs.
197    ///
198    /// # Panics
199    /// This panics if `xs` and `ys` do not have equal moduli.
200    fn bin_multiplication_lower_half(
201        &mut self,
202        xs: &BinaryBundle<Self::Item>,
203        ys: &BinaryBundle<Self::Item>,
204        channel: &mut Channel,
205    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
206        assert_eq!(xs.moduli(), ys.moduli());
207
208        let xwires = xs.wires();
209        let ywires = ys.wires();
210
211        let mut sum = xwires
212            .iter()
213            .map(|x| self.and(x, &ywires[0], channel))
214            .collect::<swanky_error::Result<Vec<Self::Item>>>()
215            .map(BinaryBundle::new)?;
216
217        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
218            let mul = xwires
219                .iter()
220                .map(|x| self.and(x, ywire, channel))
221                .collect::<swanky_error::Result<Vec<Self::Item>>>()
222                .map(BinaryBundle::new)?;
223            let shifted = self.shift(&mul, i, channel).map(BinaryBundle)?;
224            sum = self.bin_addition_no_carry(&sum, &shifted, channel)?;
225        }
226
227        Ok(sum)
228    }
229
230    /// Full multiplier.
231    ///
232    /// # Panics
233    /// This panics if `xs` and `ys` do not have equal moduli.
234    fn bin_mul(
235        &mut self,
236        xs: &BinaryBundle<Self::Item>,
237        ys: &BinaryBundle<Self::Item>,
238        channel: &mut Channel,
239    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
240        assert_eq!(xs.moduli(), ys.moduli());
241
242        let xwires = xs.wires();
243        let ywires = ys.wires();
244
245        let mut sum = xwires
246            .iter()
247            .map(|x| self.and(x, &ywires[0], channel))
248            .collect::<Result<_, _>>()
249            .map(BinaryBundle::new)?;
250
251        let zero = self.constant(0, 2, channel)?;
252        sum.pad(&zero, 1);
253
254        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
255            let mul = xwires
256                .iter()
257                .map(|x| self.and(x, ywire, channel))
258                .collect::<Result<_, _>>()
259                .map(BinaryBundle::new)?;
260            let shifted = self
261                .shift_extend(&mul, i, channel)
262                .map(BinaryBundle::from)?;
263            let res = self.bin_addition(&sum, &shifted, channel)?;
264            sum = res.0;
265            sum.push(res.1);
266        }
267
268        Ok(sum)
269    }
270
271    /// Divider.
272    ///
273    /// # Panics
274    /// This panics if `xs` and `ys` do not have equal moduli.
275    fn bin_div(
276        &mut self,
277        xs: &BinaryBundle<Self::Item>,
278        ys: &BinaryBundle<Self::Item>,
279        channel: &mut Channel,
280    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
281        assert_eq!(xs.moduli(), ys.moduli());
282        let ys_neg = self.bin_twos_complement(ys, channel)?;
283        let mut acc = self.bin_constant_bundle(0, xs.size(), channel)?;
284        let mut qs = BinaryBundle::new(Vec::new());
285        for x in xs.iter().rev() {
286            acc.pop();
287            acc.insert(0, x.clone());
288            let (res, cout) = self.bin_addition(&acc, &ys_neg, channel)?;
289            acc = self.bin_multiplex(&cout, &acc, &res, channel)?;
290            qs.push(cout);
291        }
292        qs.reverse(); // Switch back to little-endian
293        Ok(qs)
294    }
295
296    /// Compute the twos complement of the input bundle (which must be base 2).
297    fn bin_twos_complement(
298        &mut self,
299        xs: &BinaryBundle<Self::Item>,
300        channel: &mut Channel,
301    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
302        let not_xs = BinaryBundle::new(
303            xs.wires()
304                .iter()
305                .map(|x| self.negate(x))
306                .collect::<Vec<_>>(),
307        );
308        let one = self.bin_constant_bundle(1, xs.size(), channel)?;
309        self.bin_addition_no_carry(&not_xs, &one, channel)
310    }
311
312    /// Subtract two binary bundles. Returns the result and whether it underflowed.
313    ///
314    /// Due to the way that `twos_complement(0) = 0`, underflow indicates `y != 0 && x >= y`.
315    fn bin_subtraction(
316        &mut self,
317        xs: &BinaryBundle<Self::Item>,
318        ys: &BinaryBundle<Self::Item>,
319        channel: &mut Channel,
320    ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
321        let neg_ys = self.bin_twos_complement(ys, channel)?;
322        self.bin_addition(xs, &neg_ys, channel)
323    }
324
325    /// If `x=0` return `c1` as a bundle of constant bits, else return `c2`.
326    fn bin_multiplex_constant_bits(
327        &mut self,
328        x: &Self::Item,
329        c1: u128,
330        c2: u128,
331        nbits: usize,
332        channel: &mut Channel,
333    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
334        let c1_bs = util::u128_to_bits(c1, nbits)
335            .into_iter()
336            .map(|x: u16| x > 0)
337            .collect_vec();
338        let c2_bs = util::u128_to_bits(c2, nbits)
339            .into_iter()
340            .map(|x: u16| x > 0)
341            .collect_vec();
342        c1_bs
343            .into_iter()
344            .zip(c2_bs)
345            .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2, channel))
346            .collect::<swanky_error::Result<Vec<Self::Item>>>()
347            .map(BinaryBundle::new)
348    }
349
350    /// Multiplex gadget for binary bundles
351    fn bin_multiplex(
352        &mut self,
353        b: &Self::Item,
354        x: &BinaryBundle<Self::Item>,
355        y: &BinaryBundle<Self::Item>,
356        channel: &mut Channel,
357    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
358        x.wires()
359            .iter()
360            .zip(y.wires().iter())
361            .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
362            .collect::<swanky_error::Result<Vec<Self::Item>>>()
363            .map(BinaryBundle::new)
364    }
365
366    /// Write the constant in binary and that gives you the shift amounts, Eg.. 7x is 4x+2x+x.
367    fn bin_cmul(
368        &mut self,
369        x: &BinaryBundle<Self::Item>,
370        c: u128,
371        nbits: usize,
372        channel: &mut Channel,
373    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
374        let zero = self.bin_constant_bundle(0, nbits, channel)?;
375        util::u128_to_bits(c, nbits)
376            .into_iter()
377            .enumerate()
378            .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
379            .try_fold(zero, |z, shift_amt| {
380                let s = self.shift(x, shift_amt, channel).map(BinaryBundle)?;
381                self.bin_addition_no_carry(&z, &s, channel)
382            })
383    }
384
385    /// Compute the absolute value of a binary bundle.
386    fn bin_abs(
387        &mut self,
388        x: &BinaryBundle<Self::Item>,
389        channel: &mut Channel,
390    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
391        let sign = x.wires().last().unwrap();
392        let negated = self.bin_twos_complement(x, channel)?;
393        self.bin_multiplex(sign, x, &negated, channel)
394    }
395
396    /// Returns 1 if `x < y` (signed version)
397    fn bin_lt_signed(
398        &mut self,
399        x: &BinaryBundle<Self::Item>,
400        y: &BinaryBundle<Self::Item>,
401        channel: &mut Channel,
402    ) -> swanky_error::Result<Self::Item> {
403        // determine whether x and y are positive or negative
404        let x_neg = &x.wires().last().unwrap();
405        let y_neg = &y.wires().last().unwrap();
406        let x_pos = self.negate(x_neg);
407        let y_pos = self.negate(y_neg);
408
409        // broken into cases based on x and y being negative or positive
410        // base case: if x and y have the same sign - use unsigned lt
411        let x_lt_y_unsigned = self.bin_lt(x, y, channel)?;
412
413        // if x is negative and y is positive then x < y
414        let tru = self.constant(1, 2, channel)?;
415        let x_neg_y_pos = self.and(x_neg, &y_pos, channel)?;
416        let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru, channel)?;
417
418        // if x is positive and y is negative then !(x < y)
419        let fls = self.constant(0, 2, channel)?;
420        let x_pos_y_neg = self.and(&x_pos, y_neg, channel)?;
421        self.mux(&x_pos_y_neg, &r2, &fls, channel)
422    }
423
424    /// Returns 1 if `x < y`.
425    fn bin_lt(
426        &mut self,
427        x: &BinaryBundle<Self::Item>,
428        y: &BinaryBundle<Self::Item>,
429        channel: &mut Channel,
430    ) -> swanky_error::Result<Self::Item> {
431        // underflow indicates y != 0 && x >= y
432        // requiring special care to remove the y != 0, which is what follows.
433        let (_, lhs) = self.bin_subtraction(x, y, channel)?;
434
435        // Now we build a clause equal to (y == 0 || x >= y), which we can OR with
436        // lhs to remove the y==0 aspect.
437        // check if y==0
438        let y_contains_1 = self.or_many(y.wires(), channel)?;
439        let y_eq_0 = self.negate(&y_contains_1);
440
441        // if x != 0, then x >= y, ... assuming x is not negative
442        let x_contains_1 = self.or_many(x.wires(), channel)?;
443
444        // y == 0 && x >= y
445        let rhs = self.and(&y_eq_0, &x_contains_1, channel)?;
446
447        // (y != 0 && x >= y) || (y == 0 && x >= y)
448        // => x >= y && (y != 0 || y == 0)\
449        // => x >= y && 1
450        // => x >= y
451        let geq = self.or(&lhs, &rhs, channel)?;
452        let ngeq = self.negate(&geq);
453
454        let xy_neq_0 = self.or(&y_contains_1, &x_contains_1, channel)?;
455        self.and(&xy_neq_0, &ngeq, channel)
456    }
457
458    /// Returns 1 if `x >= y`.
459    fn bin_geq(
460        &mut self,
461        x: &BinaryBundle<Self::Item>,
462        y: &BinaryBundle<Self::Item>,
463        channel: &mut Channel,
464    ) -> swanky_error::Result<Self::Item> {
465        let z = self.bin_lt(x, y, channel)?;
466        Ok(self.negate(&z))
467    }
468
469    /// Compute the maximum bundle in `xs`.
470    ///
471    /// # Panics
472    /// Panics if `xs` is empty.
473    fn bin_max(
474        &mut self,
475        xs: &[BinaryBundle<Self::Item>],
476        channel: &mut Channel,
477    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
478        assert!(!xs.is_empty(), "`xs` cannot be empty");
479        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
480            let pos = self.bin_lt(&x, y, channel)?;
481            let neg = self.negate(&pos);
482            Ok(BinaryBundle::new(
483                x.wires()
484                    .iter()
485                    .zip(y.wires().iter())
486                    .map(|(x, y)| {
487                        let xp = self.and(x, &neg, channel)?;
488                        let yp = self.and(y, &pos, channel)?;
489                        Ok(self.xor(&xp, &yp))
490                    })
491                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
492            ))
493        })
494    }
495
496    /// Demux a binary bundle into a unary vector.
497    ///
498    /// # Panics
499    /// Panics if the length of `x` is greater than eight.
500    fn bin_demux(
501        &mut self,
502        x: &BinaryBundle<Self::Item>,
503        channel: &mut Channel,
504    ) -> swanky_error::Result<Vec<Self::Item>> {
505        let wires = x.wires();
506        let nbits = wires.len();
507        assert!(nbits <= 8, "wire bitlength is too large");
508
509        let mut outs = Vec::with_capacity(1 << nbits);
510
511        for ix in 0..1 << nbits {
512            let mut acc = wires[0].clone();
513            if (ix & 1) == 0 {
514                acc = self.negate(&acc);
515            }
516            for (i, w) in wires.iter().enumerate().skip(1) {
517                if ((ix >> i) & 1) > 0 {
518                    acc = self.and(&acc, w, channel)?;
519                } else {
520                    let not_w = self.negate(w);
521                    acc = self.and(&acc, &not_w, channel)?;
522                }
523            }
524            outs.push(acc);
525        }
526
527        Ok(outs)
528    }
529
530    /// arithmetic right shift (shifts the sign of the MSB into the new spaces)
531    fn bin_rsa(&mut self, x: &BinaryBundle<Self::Item>, c: usize) -> BinaryBundle<Self::Item> {
532        self.bin_shr(x, c, x.wires().last().unwrap())
533    }
534
535    /// logical right shift (shifts 0 into the empty spaces)
536    fn bin_rsl(
537        &mut self,
538        x: &BinaryBundle<Self::Item>,
539        c: usize,
540        channel: &mut Channel,
541    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
542        let zero = self.constant(0, 2, channel)?;
543        Ok(self.bin_shr(x, c, &zero))
544    }
545
546    /// shift a value right by a constant, filling space on the right by `pad`
547    fn bin_shr(
548        &mut self,
549        x: &BinaryBundle<Self::Item>,
550        c: usize,
551        pad: &Self::Item,
552    ) -> BinaryBundle<Self::Item> {
553        let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
554
555        for i in 0..x.wires().len() {
556            let src_idx = i + c;
557            if src_idx >= x.wires().len() {
558                wires.push(pad.clone())
559            } else {
560                wires.push(x.wires()[src_idx].clone())
561            }
562        }
563
564        BinaryBundle::new(wires)
565    }
566    /// Compute `x == y` for binary bundles.
567    fn bin_eq_bundles(
568        &mut self,
569        x: &BinaryBundle<Self::Item>,
570        y: &BinaryBundle<Self::Item>,
571        channel: &mut Channel,
572    ) -> swanky_error::Result<Self::Item> {
573        // compute (x^y == 0) for each residue
574        let zs = x
575            .wires()
576            .iter()
577            .zip_eq(y.wires().iter())
578            .map(|(x, y)| {
579                let xy = self.xor(x, y);
580                self.negate(&xy)
581            })
582            .collect::<Vec<_>>();
583        // and_many will return 1 only if all outputs of xnor are 1
584        // indicating equality
585        self.and_many(&zs, channel)
586    }
587}