Skip to main content

fancy_garbling/fancy/
binary.rs

1use crate::{
2    FancyBinary,
3    fancy::{
4        HasModulus,
5        bundle::{Bundle, BundleGadgets},
6    },
7    util,
8};
9use itertools::Itertools;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12use std::ops::{Deref, DerefMut};
13use swanky_channel::Channel;
14
15/// Bundle which is explicitly binary representation.
16#[derive(Clone)]
17#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
18pub struct BinaryBundle<W>(Bundle<W>);
19
20impl<W: Clone + HasModulus> BinaryBundle<W> {
21    /// Create a new binary bundle from a vector of wires.
22    pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
23        BinaryBundle(Bundle::new(ws))
24    }
25
26    /// Extract the underlying bundle from this binary bundle.
27    pub fn extract(self) -> Bundle<W> {
28        self.0
29    }
30}
31
32impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
33    type Target = Bundle<W>;
34
35    fn deref(&self) -> &Bundle<W> {
36        &self.0
37    }
38}
39
40impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
41    fn deref_mut(&mut self) -> &mut Bundle<W> {
42        &mut self.0
43    }
44}
45
46impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
47    fn from(b: Bundle<W>) -> BinaryBundle<W> {
48        debug_assert!(b.moduli().iter().all(|&p| p == 2));
49        BinaryBundle(b)
50    }
51}
52
53impl<F: FancyBinary> BinaryGadgets for F {}
54
55/// Extension trait for `Fancy` providing gadgets that operate over bundles of mod2 wires.
56pub trait BinaryGadgets: FancyBinary + BundleGadgets {
57    /// Encode a binary input bundle.
58    fn bin_encode(
59        &mut self,
60        value: u128,
61        nbits: usize,
62        channel: &mut Channel,
63    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
64        let xs = util::u128_to_bits(value, nbits);
65        self.encode_bundle(&xs, &vec![2; nbits], channel)
66            .map(BinaryBundle::from)
67    }
68
69    /// Receive an binary input bundle.
70    fn bin_receive(
71        &mut self,
72        nbits: usize,
73        channel: &mut Channel,
74    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
75        self.receive_bundle(&vec![2; nbits], channel)
76            .map(BinaryBundle::from)
77    }
78
79    /// Encode many binary input bundles.
80    fn bin_encode_many(
81        &mut self,
82        values: &[u128],
83        nbits: usize,
84        channel: &mut Channel,
85    ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
86        let xs = values
87            .iter()
88            .flat_map(|x| util::u128_to_bits(*x, nbits))
89            .collect_vec();
90        let mut wires = self.encode_many(&xs, &vec![2; values.len() * nbits], channel)?;
91        let buns = (0..values.len())
92            .map(|_| {
93                let ws = wires.drain(0..nbits).collect_vec();
94                BinaryBundle::new(ws)
95            })
96            .collect_vec();
97        Ok(buns)
98    }
99
100    /// Receive many binary input bundles.
101    fn bin_receive_many(
102        &mut self,
103        ninputs: usize,
104        nbits: usize,
105        channel: &mut Channel,
106    ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
107        let mut wires = self.receive_many(&vec![2; ninputs * nbits], channel)?;
108        let buns = (0..ninputs)
109            .map(|_| {
110                let ws = wires.drain(0..nbits).collect_vec();
111                BinaryBundle::new(ws)
112            })
113            .collect_vec();
114        Ok(buns)
115    }
116
117    /// Create a constant bundle using base 2 inputs.
118    fn bin_constant_bundle(
119        &mut self,
120        val: u128,
121        nbits: usize,
122        channel: &mut Channel,
123    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
124        self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits], channel)
125            .map(BinaryBundle)
126    }
127
128    /// Output a binary bundle and interpret the result as a `u128`.
129    fn bin_output(
130        &mut self,
131        x: &BinaryBundle<Self::Item>,
132        channel: &mut Channel,
133    ) -> swanky_error::Result<Option<u128>> {
134        Ok(self
135            .output_bundle(x, channel)?
136            .map(|bs| util::u128_from_bits(&bs)))
137    }
138
139    /// Output a slice of binary bundles and interpret the results as a `u128`.
140    fn bin_outputs(
141        &mut self,
142        xs: &[BinaryBundle<Self::Item>],
143        channel: &mut Channel,
144    ) -> swanky_error::Result<Option<Vec<u128>>> {
145        let mut zs = Vec::with_capacity(xs.len());
146        for x in xs.iter() {
147            let z = self.bin_output(x, channel)?;
148            zs.push(z);
149        }
150        Ok(zs.into_iter().collect())
151    }
152
153    /// Xor the bits of two bundles together pairwise.
154    fn bin_xor(
155        &mut self,
156        x: &BinaryBundle<Self::Item>,
157        y: &BinaryBundle<Self::Item>,
158    ) -> BinaryBundle<Self::Item> {
159        BinaryBundle::new(
160            x.wires()
161                .iter()
162                .zip(y.wires().iter())
163                .map(|(x, y)| self.xor(x, y))
164                .collect::<Vec<Self::Item>>(),
165        )
166    }
167
168    /// And the bits of two bundles together pairwise.
169    fn bin_and(
170        &mut self,
171        x: &BinaryBundle<Self::Item>,
172        y: &BinaryBundle<Self::Item>,
173        channel: &mut Channel,
174    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
175        x.wires()
176            .iter()
177            .zip(y.wires().iter())
178            .map(|(x, y)| self.and(x, y, channel))
179            .collect::<swanky_error::Result<Vec<Self::Item>>>()
180            .map(BinaryBundle::new)
181    }
182
183    /// Or the bits of two bundles together pairwise.
184    fn bin_or(
185        &mut self,
186        x: &BinaryBundle<Self::Item>,
187        y: &BinaryBundle<Self::Item>,
188        channel: &mut Channel,
189    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
190        x.wires()
191            .iter()
192            .zip(y.wires().iter())
193            .map(|(x, y)| self.or(x, y, channel))
194            .collect::<swanky_error::Result<Vec<Self::Item>>>()
195            .map(BinaryBundle::new)
196    }
197
198    /// Binary addition. Returns the result and the carry.
199    ///
200    /// # Panics
201    /// This panics if `xs` and `ys` do not have equal moduli.
202    fn bin_addition(
203        &mut self,
204        xs: &BinaryBundle<Self::Item>,
205        ys: &BinaryBundle<Self::Item>,
206        channel: &mut Channel,
207    ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
208        assert_eq!(xs.moduli(), ys.moduli());
209        let xwires = xs.wires();
210        let ywires = ys.wires();
211        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
212        let mut bs = vec![z];
213        for i in 1..xwires.len() {
214            let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
215            z = res.0;
216            c = res.1;
217            bs.push(z);
218        }
219        Ok((BinaryBundle::new(bs), c))
220    }
221
222    /// Binary addition. Avoids creating extra gates for the final carry.
223    ///
224    /// # Panics
225    /// This panics if `xs` and `ys` do not have equal moduli.
226    fn bin_addition_no_carry(
227        &mut self,
228        xs: &BinaryBundle<Self::Item>,
229        ys: &BinaryBundle<Self::Item>,
230        channel: &mut Channel,
231    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
232        assert_eq!(xs.moduli(), ys.moduli());
233        let xwires = xs.wires();
234        let ywires = ys.wires();
235        let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
236        let mut bs = vec![z];
237        for i in 1..xwires.len() - 1 {
238            let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
239            z = res.0;
240            c = res.1;
241            bs.push(z);
242        }
243        // xor instead of add
244        z = self.xor_many(&[
245            xwires.last().unwrap().clone(),
246            ywires.last().unwrap().clone(),
247            c,
248        ]);
249        bs.push(z);
250        Ok(BinaryBundle::new(bs))
251    }
252
253    /// Binary multiplication.
254    ///
255    /// Returns the lower-order half of the output bits, ie a number with the same number
256    /// of bits as the inputs.
257    ///
258    /// # Panics
259    /// This panics if `xs` and `ys` do not have equal moduli.
260    fn bin_multiplication_lower_half(
261        &mut self,
262        xs: &BinaryBundle<Self::Item>,
263        ys: &BinaryBundle<Self::Item>,
264        channel: &mut Channel,
265    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
266        assert_eq!(xs.moduli(), ys.moduli());
267
268        let xwires = xs.wires();
269        let ywires = ys.wires();
270
271        let mut sum = xwires
272            .iter()
273            .map(|x| self.and(x, &ywires[0], channel))
274            .collect::<swanky_error::Result<Vec<Self::Item>>>()
275            .map(BinaryBundle::new)?;
276
277        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
278            let mul = xwires
279                .iter()
280                .map(|x| self.and(x, ywire, channel))
281                .collect::<swanky_error::Result<Vec<Self::Item>>>()
282                .map(BinaryBundle::new)?;
283            let shifted = self.shift(&mul, i, channel).map(BinaryBundle)?;
284            sum = self.bin_addition_no_carry(&sum, &shifted, channel)?;
285        }
286
287        Ok(sum)
288    }
289
290    /// Full multiplier.
291    ///
292    /// # Panics
293    /// This panics if `xs` and `ys` do not have equal moduli.
294    fn bin_mul(
295        &mut self,
296        xs: &BinaryBundle<Self::Item>,
297        ys: &BinaryBundle<Self::Item>,
298        channel: &mut Channel,
299    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
300        assert_eq!(xs.moduli(), ys.moduli());
301
302        let xwires = xs.wires();
303        let ywires = ys.wires();
304
305        let mut sum = xwires
306            .iter()
307            .map(|x| self.and(x, &ywires[0], channel))
308            .collect::<Result<_, _>>()
309            .map(BinaryBundle::new)?;
310
311        let zero = self.constant(0, 2, channel)?;
312        sum.pad(&zero, 1);
313
314        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
315            let mul = xwires
316                .iter()
317                .map(|x| self.and(x, ywire, channel))
318                .collect::<Result<_, _>>()
319                .map(BinaryBundle::new)?;
320            let shifted = self
321                .shift_extend(&mul, i, channel)
322                .map(BinaryBundle::from)?;
323            let res = self.bin_addition(&sum, &shifted, channel)?;
324            sum = res.0;
325            sum.push(res.1);
326        }
327
328        Ok(sum)
329    }
330
331    /// Divider.
332    ///
333    /// # Panics
334    /// This panics if `xs` and `ys` do not have equal moduli.
335    fn bin_div(
336        &mut self,
337        xs: &BinaryBundle<Self::Item>,
338        ys: &BinaryBundle<Self::Item>,
339        channel: &mut Channel,
340    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
341        assert_eq!(xs.moduli(), ys.moduli());
342        let ys_neg = self.bin_twos_complement(ys, channel)?;
343        let mut acc = self.bin_constant_bundle(0, xs.size(), channel)?;
344        let mut qs = BinaryBundle::new(Vec::new());
345        for x in xs.iter().rev() {
346            acc.pop();
347            acc.insert(0, x.clone());
348            let (res, cout) = self.bin_addition(&acc, &ys_neg, channel)?;
349            acc = self.bin_multiplex(&cout, &acc, &res, channel)?;
350            qs.push(cout);
351        }
352        qs.reverse(); // Switch back to little-endian
353        Ok(qs)
354    }
355
356    /// Compute the twos complement of the input bundle (which must be base 2).
357    fn bin_twos_complement(
358        &mut self,
359        xs: &BinaryBundle<Self::Item>,
360        channel: &mut Channel,
361    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
362        let not_xs = BinaryBundle::new(
363            xs.wires()
364                .iter()
365                .map(|x| self.negate(x))
366                .collect::<Vec<_>>(),
367        );
368        let one = self.bin_constant_bundle(1, xs.size(), channel)?;
369        self.bin_addition_no_carry(&not_xs, &one, channel)
370    }
371
372    /// Subtract two binary bundles. Returns the result and whether it underflowed.
373    ///
374    /// Due to the way that `twos_complement(0) = 0`, underflow indicates `y != 0 && x >= y`.
375    fn bin_subtraction(
376        &mut self,
377        xs: &BinaryBundle<Self::Item>,
378        ys: &BinaryBundle<Self::Item>,
379        channel: &mut Channel,
380    ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
381        let neg_ys = self.bin_twos_complement(ys, channel)?;
382        self.bin_addition(xs, &neg_ys, channel)
383    }
384
385    /// If `x=0` return `c1` as a bundle of constant bits, else return `c2`.
386    fn bin_multiplex_constant_bits(
387        &mut self,
388        x: &Self::Item,
389        c1: u128,
390        c2: u128,
391        nbits: usize,
392        channel: &mut Channel,
393    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
394        let c1_bs = util::u128_to_bits(c1, nbits)
395            .into_iter()
396            .map(|x: u16| x > 0)
397            .collect_vec();
398        let c2_bs = util::u128_to_bits(c2, nbits)
399            .into_iter()
400            .map(|x: u16| x > 0)
401            .collect_vec();
402        c1_bs
403            .into_iter()
404            .zip(c2_bs)
405            .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2, channel))
406            .collect::<swanky_error::Result<Vec<Self::Item>>>()
407            .map(BinaryBundle::new)
408    }
409
410    /// Multiplex gadget for binary bundles
411    fn bin_multiplex(
412        &mut self,
413        b: &Self::Item,
414        x: &BinaryBundle<Self::Item>,
415        y: &BinaryBundle<Self::Item>,
416        channel: &mut Channel,
417    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
418        x.wires()
419            .iter()
420            .zip(y.wires().iter())
421            .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
422            .collect::<swanky_error::Result<Vec<Self::Item>>>()
423            .map(BinaryBundle::new)
424    }
425
426    /// Write the constant in binary and that gives you the shift amounts, Eg.. 7x is 4x+2x+x.
427    fn bin_cmul(
428        &mut self,
429        x: &BinaryBundle<Self::Item>,
430        c: u128,
431        nbits: usize,
432        channel: &mut Channel,
433    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
434        let zero = self.bin_constant_bundle(0, nbits, channel)?;
435        util::u128_to_bits(c, nbits)
436            .into_iter()
437            .enumerate()
438            .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
439            .try_fold(zero, |z, shift_amt| {
440                let s = self.shift(x, shift_amt, channel).map(BinaryBundle)?;
441                self.bin_addition_no_carry(&z, &s, channel)
442            })
443    }
444
445    /// Compute the absolute value of a binary bundle.
446    fn bin_abs(
447        &mut self,
448        x: &BinaryBundle<Self::Item>,
449        channel: &mut Channel,
450    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
451        let sign = x.wires().last().unwrap();
452        let negated = self.bin_twos_complement(x, channel)?;
453        self.bin_multiplex(sign, x, &negated, channel)
454    }
455
456    /// Returns 1 if `x < y` (signed version)
457    fn bin_lt_signed(
458        &mut self,
459        x: &BinaryBundle<Self::Item>,
460        y: &BinaryBundle<Self::Item>,
461        channel: &mut Channel,
462    ) -> swanky_error::Result<Self::Item> {
463        // determine whether x and y are positive or negative
464        let x_neg = &x.wires().last().unwrap();
465        let y_neg = &y.wires().last().unwrap();
466        let x_pos = self.negate(x_neg);
467        let y_pos = self.negate(y_neg);
468
469        // broken into cases based on x and y being negative or positive
470        // base case: if x and y have the same sign - use unsigned lt
471        let x_lt_y_unsigned = self.bin_lt(x, y, channel)?;
472
473        // if x is negative and y is positive then x < y
474        let tru = self.constant(1, 2, channel)?;
475        let x_neg_y_pos = self.and(x_neg, &y_pos, channel)?;
476        let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru, channel)?;
477
478        // if x is positive and y is negative then !(x < y)
479        let fls = self.constant(0, 2, channel)?;
480        let x_pos_y_neg = self.and(&x_pos, y_neg, channel)?;
481        self.mux(&x_pos_y_neg, &r2, &fls, channel)
482    }
483
484    /// Returns 1 if `x < y`.
485    fn bin_lt(
486        &mut self,
487        x: &BinaryBundle<Self::Item>,
488        y: &BinaryBundle<Self::Item>,
489        channel: &mut Channel,
490    ) -> swanky_error::Result<Self::Item> {
491        // underflow indicates y != 0 && x >= y
492        // requiring special care to remove the y != 0, which is what follows.
493        let (_, lhs) = self.bin_subtraction(x, y, channel)?;
494
495        // Now we build a clause equal to (y == 0 || x >= y), which we can OR with
496        // lhs to remove the y==0 aspect.
497        // check if y==0
498        let y_contains_1 = self.or_many(y.wires(), channel)?;
499        let y_eq_0 = self.negate(&y_contains_1);
500
501        // if x != 0, then x >= y, ... assuming x is not negative
502        let x_contains_1 = self.or_many(x.wires(), channel)?;
503
504        // y == 0 && x >= y
505        let rhs = self.and(&y_eq_0, &x_contains_1, channel)?;
506
507        // (y != 0 && x >= y) || (y == 0 && x >= y)
508        // => x >= y && (y != 0 || y == 0)\
509        // => x >= y && 1
510        // => x >= y
511        let geq = self.or(&lhs, &rhs, channel)?;
512        let ngeq = self.negate(&geq);
513
514        let xy_neq_0 = self.or(&y_contains_1, &x_contains_1, channel)?;
515        self.and(&xy_neq_0, &ngeq, channel)
516    }
517
518    /// Returns 1 if `x >= y`.
519    fn bin_geq(
520        &mut self,
521        x: &BinaryBundle<Self::Item>,
522        y: &BinaryBundle<Self::Item>,
523        channel: &mut Channel,
524    ) -> swanky_error::Result<Self::Item> {
525        let z = self.bin_lt(x, y, channel)?;
526        Ok(self.negate(&z))
527    }
528
529    /// Compute the maximum bundle in `xs`.
530    ///
531    /// # Panics
532    /// Panics if `xs` is empty.
533    fn bin_max(
534        &mut self,
535        xs: &[BinaryBundle<Self::Item>],
536        channel: &mut Channel,
537    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
538        assert!(!xs.is_empty(), "`xs` cannot be empty");
539        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
540            let pos = self.bin_lt(&x, y, channel)?;
541            let neg = self.negate(&pos);
542            Ok(BinaryBundle::new(
543                x.wires()
544                    .iter()
545                    .zip(y.wires().iter())
546                    .map(|(x, y)| {
547                        let xp = self.and(x, &neg, channel)?;
548                        let yp = self.and(y, &pos, channel)?;
549                        Ok(self.xor(&xp, &yp))
550                    })
551                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
552            ))
553        })
554    }
555
556    /// Demux a binary bundle into a unary vector.
557    ///
558    /// # Panics
559    /// Panics if the length of `x` is greater than eight.
560    fn bin_demux(
561        &mut self,
562        x: &BinaryBundle<Self::Item>,
563        channel: &mut Channel,
564    ) -> swanky_error::Result<Vec<Self::Item>> {
565        let wires = x.wires();
566        let nbits = wires.len();
567        assert!(nbits <= 8, "wire bitlength is too large");
568
569        let mut outs = Vec::with_capacity(1 << nbits);
570
571        for ix in 0..1 << nbits {
572            let mut acc = wires[0].clone();
573            if (ix & 1) == 0 {
574                acc = self.negate(&acc);
575            }
576            for (i, w) in wires.iter().enumerate().skip(1) {
577                if ((ix >> i) & 1) > 0 {
578                    acc = self.and(&acc, w, channel)?;
579                } else {
580                    let not_w = self.negate(w);
581                    acc = self.and(&acc, &not_w, channel)?;
582                }
583            }
584            outs.push(acc);
585        }
586
587        Ok(outs)
588    }
589
590    /// arithmetic right shift (shifts the sign of the MSB into the new spaces)
591    fn bin_rsa(&mut self, x: &BinaryBundle<Self::Item>, c: usize) -> BinaryBundle<Self::Item> {
592        self.bin_shr(x, c, x.wires().last().unwrap())
593    }
594
595    /// logical right shift (shifts 0 into the empty spaces)
596    fn bin_rsl(
597        &mut self,
598        x: &BinaryBundle<Self::Item>,
599        c: usize,
600        channel: &mut Channel,
601    ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
602        let zero = self.constant(0, 2, channel)?;
603        Ok(self.bin_shr(x, c, &zero))
604    }
605
606    /// shift a value right by a constant, filling space on the right by `pad`
607    fn bin_shr(
608        &mut self,
609        x: &BinaryBundle<Self::Item>,
610        c: usize,
611        pad: &Self::Item,
612    ) -> BinaryBundle<Self::Item> {
613        let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
614
615        for i in 0..x.wires().len() {
616            let src_idx = i + c;
617            if src_idx >= x.wires().len() {
618                wires.push(pad.clone())
619            } else {
620                wires.push(x.wires()[src_idx].clone())
621            }
622        }
623
624        BinaryBundle::new(wires)
625    }
626    /// Compute `x == y` for binary bundles.
627    fn bin_eq_bundles(
628        &mut self,
629        x: &BinaryBundle<Self::Item>,
630        y: &BinaryBundle<Self::Item>,
631        channel: &mut Channel,
632    ) -> swanky_error::Result<Self::Item> {
633        // compute (x^y == 0) for each residue
634        let zs = x
635            .wires()
636            .iter()
637            .zip_eq(y.wires().iter())
638            .map(|(x, y)| {
639                let xy = self.xor(x, y);
640                self.negate(&xy)
641            })
642            .collect::<Vec<_>>();
643        // and_many will return 1 only if all outputs of xnor are 1
644        // indicating equality
645        self.and_many(&zs, channel)
646    }
647}