Skip to main content

fancy_garbling/fancy/
bundle.rs

1use crate::{
2    FancyArithmetic, FancyBinary, FancyProj,
3    fancy::{Fancy, HasModulus},
4};
5use itertools::Itertools;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::ops::Index;
9use swanky_channel::Channel;
10
/// A collection of wires, useful for the garbled gadgets defined by [`BundleGadgets`].
///
/// The wires are stored in order; gadgets in this file treat index 0 as the first
/// (lowest-index) position of the bundle.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Bundle<W>(Vec<W>);
15
16impl<W: Clone + HasModulus> Bundle<W> {
17    /// Create a new bundle from some wires.
18    pub fn new(ws: Vec<W>) -> Bundle<W> {
19        Bundle(ws)
20    }
21
22    /// Return the moduli of all the wires in the bundle.
23    pub fn moduli(&self) -> Vec<u16> {
24        self.0.iter().map(HasModulus::modulus).collect()
25    }
26
27    /// Extract the wires from this bundle.
28    pub fn wires(&self) -> &[W] {
29        &self.0
30    }
31
32    /// Get the number of wires in this bundle.
33    pub fn size(&self) -> usize {
34        self.0.len()
35    }
36
37    /// Whether this bundle only contains residues in mod 2.
38    pub fn is_binary(&self) -> bool {
39        self.moduli().iter().all(|m| *m == 2)
40    }
41
42    /// Returns a new bundle only containing wires with matching moduli.
43    pub fn with_moduli(&self, moduli: &[u16]) -> Bundle<W> {
44        let old_ws = self.wires();
45        let mut new_ws = Vec::with_capacity(moduli.len());
46        for &p in moduli {
47            if let Some(w) = old_ws.iter().find(|&x| x.modulus() == p) {
48                new_ws.push(w.clone());
49            } else {
50                panic!("Bundle::with_moduli: no {} modulus in bundle", p);
51            }
52        }
53        Bundle(new_ws)
54    }
55
56    /// Pad the Bundle with val, n times.
57    pub fn pad(&mut self, val: &W, n: usize) {
58        for _ in 0..n {
59            self.0.push(val.clone());
60        }
61    }
62
63    /// Extract a wire from the Bundle, removing it and returning it.
64    pub fn extract(&mut self, wire_index: usize) -> W {
65        self.0.remove(wire_index)
66    }
67
68    /// Insert a wire from the Bundle
69    pub fn insert(&mut self, wire_index: usize, val: W) {
70        self.0.insert(wire_index, val)
71    }
72
73    /// push a wire onto the Bundle.
74    pub fn push(&mut self, val: W) {
75        self.0.push(val);
76    }
77
78    /// Pop a wire from the Bundle.
79    pub fn pop(&mut self) -> Option<W> {
80        self.0.pop()
81    }
82
83    /// Access the underlying iterator
84    pub fn iter(&self) -> std::slice::Iter<'_, W> {
85        self.0.iter()
86    }
87
88    /// Reverse the wires
89    pub fn reverse(&mut self) {
90        self.0.reverse();
91    }
92}
93
94impl<W: Clone + HasModulus> Index<usize> for Bundle<W> {
95    type Output = W;
96
97    fn index(&self, idx: usize) -> &Self::Output {
98        self.0.index(idx)
99    }
100}
101
102impl<F: Fancy> BundleGadgets for F {}
103impl<F: FancyArithmetic> ArithmeticBundleGadgets for F {}
104impl<F: FancyArithmetic + FancyProj> ArithmeticProjBundleGadgets for F {}
105impl<F: FancyBinary> BinaryBundleGadgets for F {}
106
107/// Arithmetic operations on wire bundles, extending the capability of `FancyArithmetic` operating
108/// on individual wires.
109pub trait ArithmeticBundleGadgets: FancyArithmetic {
110    /// Add two wire bundles pairwise, zipping addition.
111    ///
112    /// In CRT this is plain addition. In binary this is xor.
113    ///
114    /// # Panics
115    /// Panics if `x` and `y` are not of the same length.
116    fn add_bundles(
117        &mut self,
118        x: &Bundle<Self::Item>,
119        y: &Bundle<Self::Item>,
120    ) -> Bundle<Self::Item> {
121        assert_eq!(
122            x.wires().len(),
123            y.wires().len(),
124            "`x` and `y` must be the same length"
125        );
126        Bundle::new(
127            x.wires()
128                .iter()
129                .zip(y.wires().iter())
130                .map(|(x, y)| self.add(x, y))
131                .collect::<Vec<Self::Item>>(),
132        )
133    }
134
135    /// Subtract two wire bundles, residue by residue.
136    ///
137    /// In CRT this is plain subtraction. In binary this is `xor`.
138    ///
139    /// # Panics
140    /// Panics if `x` and `y` are not of the same length.
141    fn sub_bundles(
142        &mut self,
143        x: &Bundle<Self::Item>,
144        y: &Bundle<Self::Item>,
145    ) -> Bundle<Self::Item> {
146        assert_eq!(
147            x.wires().len(),
148            y.wires().len(),
149            "`x` and `y` must be the same length"
150        );
151        Bundle::new(
152            x.wires()
153                .iter()
154                .zip(y.wires().iter())
155                .map(|(x, y)| self.sub(x, y))
156                .collect::<Vec<Self::Item>>(),
157        )
158    }
159
160    /// Multiply each wire in `x` with each wire in `y`, pairwise.
161    ///
162    /// In CRT this is plain multiplication. In binary this is `and`.
163    fn mul_bundles(
164        &mut self,
165        x: &Bundle<Self::Item>,
166        y: &Bundle<Self::Item>,
167        channel: &mut Channel,
168    ) -> swanky_error::Result<Bundle<Self::Item>> {
169        x.wires()
170            .iter()
171            .zip(y.wires().iter())
172            .map(|(x, y)| self.mul(x, y, channel))
173            .collect::<swanky_error::Result<_>>()
174            .map(Bundle::new)
175    }
176
177    /// If b=0 then return 0, else return x.
178    fn mask(
179        &mut self,
180        b: &Self::Item,
181        x: &Bundle<Self::Item>,
182        channel: &mut Channel,
183    ) -> swanky_error::Result<Bundle<Self::Item>> {
184        x.wires()
185            .iter()
186            .map(|xwire| self.mul(xwire, b, channel))
187            .collect::<swanky_error::Result<_>>()
188            .map(Bundle)
189    }
190}
191
192/// Arithmetic operations on wire bundles that utilize projection gates.
193pub trait ArithmeticProjBundleGadgets: FancyArithmetic + FancyProj {
194    /// Mixed radix addition.
195    ///
196    /// # Panics
197    /// Panics if `xs` is empty, or the moduli in `xs` are not all equal.
198    fn mixed_radix_addition(
199        &mut self,
200        xs: &[Bundle<Self::Item>],
201        channel: &mut Channel,
202    ) -> swanky_error::Result<Bundle<Self::Item>> {
203        assert!(!xs.is_empty(), "`xs` cannot be empty");
204        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
205
206        let nargs = xs.len();
207        let n = xs[0].wires().len();
208
209        let mut digit_carry = None;
210        let mut carry_carry = None;
211        let mut max_carry = 0;
212
213        let mut res = Vec::with_capacity(n);
214
215        for i in 0..n {
216            // all the ith digits, in one vec
217            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
218
219            // compute the digit -- easy
220            let digit_sum = self.add_many(&ds);
221            let digit = digit_carry.map_or(digit_sum.clone(), |d| self.add(&digit_sum, &d));
222
223            if i < n - 1 {
224                // compute the carries
225                let q = xs[0].wires()[i].modulus();
226                // max_carry currently contains the max carry from the previous iteration
227                let max_val = nargs as u16 * (q - 1) + max_carry;
228                // now it is the max carry of this iteration
229                max_carry = max_val / q;
230
231                let modded_ds = ds
232                    .iter()
233                    .map(|d| self.mod_change(d, max_val + 1, channel))
234                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
235
236                let carry_sum = self.add_many(&modded_ds);
237                // add in the carry from the previous iteration
238                let carry = carry_carry.map_or(carry_sum.clone(), |c| self.add(&carry_sum, &c));
239
240                // carry now contains the carry information, we just have to project it to
241                // the correct moduli for the next iteration
242                let next_mod = xs[0].wires()[i + 1].modulus();
243                let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
244                digit_carry = Some(self.proj(&carry, next_mod, Some(tt), channel)?);
245
246                let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
247
248                if i < n - 2 {
249                    if max_carry < next_mod {
250                        carry_carry = Some(self.mod_change(
251                            digit_carry.as_ref().unwrap(),
252                            next_max_val + 1,
253                            channel,
254                        )?);
255                    } else {
256                        let tt = (0..=max_val).map(|i| i / q).collect_vec();
257                        carry_carry =
258                            Some(self.proj(&carry, next_max_val + 1, Some(tt), channel)?);
259                    }
260                } else {
261                    // next digit is MSB so we dont need carry_carry
262                    carry_carry = None;
263                }
264            } else {
265                digit_carry = None;
266                carry_carry = None;
267            }
268            res.push(digit);
269        }
270        Ok(Bundle(res))
271    }
272
273    /// Mixed radix addition only returning the MSB.
274    ///
275    /// # Panics
276    /// Panics if `xs` is empty, or the moduli in `xs` are not all equal.
277    fn mixed_radix_addition_msb_only(
278        &mut self,
279        xs: &[Bundle<Self::Item>],
280        channel: &mut Channel,
281    ) -> swanky_error::Result<Self::Item> {
282        assert!(!xs.is_empty(), "`xs` cannot be empty");
283        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
284
285        let nargs = xs.len();
286        let n = xs[0].wires().len();
287
288        let mut opt_carry = None;
289        let mut max_carry = 0;
290
291        for i in 0..n - 1 {
292            // all the ith digits, in one vec
293            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
294            // compute the carry
295            let q = xs[0].moduli()[i];
296            // max_carry currently contains the max carry from the previous iteration
297            let max_val = nargs as u16 * (q - 1) + max_carry;
298            // now it is the max carry of this iteration
299            max_carry = max_val / q;
300
301            // mod change the digits to the max sum possible plus the max carry of the
302            // previous iteration
303            let modded_ds = ds
304                .iter()
305                .map(|d| self.mod_change(d, max_val + 1, channel))
306                .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
307            // add them up
308            let sum = self.add_many(&modded_ds);
309            // add in the carry
310            let sum_with_carry = opt_carry
311                .as_ref()
312                .map_or(sum.clone(), |c| self.add(&sum, c));
313
314            // carry now contains the carry information, we just have to project it to
315            // the correct moduli for the next iteration. It will either be used to
316            // compute the next carry, if i < n-2, or it will be used to compute the
317            // output MSB, in which case it should be the modulus of the SB
318            let next_mod = if i < n - 2 {
319                nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
320            } else {
321                xs[0].moduli()[i + 1] // we will be adding the carry to the MSB
322            };
323
324            let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
325            opt_carry = Some(self.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
326        }
327
328        // compute the msb
329        let ds = xs.iter().map(|x| x.wires()[n - 1].clone()).collect_vec();
330        let digit_sum = self.add_many(&ds);
331        Ok(opt_carry
332            .as_ref()
333            .map_or(digit_sum.clone(), |d| self.add(&digit_sum, d)))
334    }
335
336    /// Compute `x == y`. Returns a wire encoding the result mod 2.
337    ///
338    /// # Panics
339    /// Panics if `x` and `y` do not have equal moduli.
340    fn eq_bundles(
341        &mut self,
342        x: &Bundle<Self::Item>,
343        y: &Bundle<Self::Item>,
344        channel: &mut Channel,
345    ) -> swanky_error::Result<Self::Item> {
346        assert_eq!(x.moduli(), y.moduli());
347
348        let wlen = x.wires().len() as u16;
349        let zs = x
350            .wires()
351            .iter()
352            .zip_eq(y.wires().iter())
353            .map(|(x, y)| {
354                // compute (x-y == 0) for each residue
355                let z = self.sub(x, y);
356                let mut eq_zero_tab = vec![0; x.modulus() as usize];
357                eq_zero_tab[0] = 1;
358                self.proj(&z, wlen + 1, Some(eq_zero_tab), channel)
359            })
360            .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
361        // add up the results, and output whether they equal zero or not, mod 2
362        let z = self.add_many(&zs);
363        let b = zs.len();
364        let mut tab = vec![0; b + 1];
365        tab[b] = 1;
366        self.proj(&z, 2, Some(tab), channel)
367    }
368}
369
370/// Binary operations on wire bundles, extending the capability of `FancyBinary` operating
371/// on individual wires.
372pub trait BinaryBundleGadgets: FancyBinary {
373    /// If b=0 then return x, else return y.
374    fn multiplex(
375        &mut self,
376        b: &Self::Item,
377        x: &Bundle<Self::Item>,
378        y: &Bundle<Self::Item>,
379        channel: &mut Channel,
380    ) -> swanky_error::Result<Bundle<Self::Item>> {
381        x.wires()
382            .iter()
383            .zip(y.wires().iter())
384            .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
385            .collect::<swanky_error::Result<_>>()
386            .map(Bundle)
387    }
388}
389
390/// Extension trait for Fancy which provides Bundle constructions which are not
391/// necessarily CRT nor binary-based.
392pub trait BundleGadgets: Fancy {
393    /// Encode a bundle.
394    fn encode_bundle(
395        &mut self,
396        values: &[u16],
397        moduli: &[u16],
398        channel: &mut Channel,
399    ) -> swanky_error::Result<Bundle<Self::Item>> {
400        self.encode_many(values, moduli, channel).map(Bundle::new)
401    }
402
403    /// Receive a bundle.
404    fn receive_bundle(
405        &mut self,
406        moduli: &[u16],
407        channel: &mut Channel,
408    ) -> swanky_error::Result<Bundle<Self::Item>> {
409        self.receive_many(moduli, channel).map(Bundle::new)
410    }
411
412    /// Encode many input bundles.
413    ///
414    /// # Panics,
415    /// Panics if `values` and `moduli` are of unequal length.
416    fn encode_bundles(
417        &mut self,
418        values: &[Vec<u16>],
419        moduli: &[Vec<u16>],
420        channel: &mut Channel,
421    ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
422        let qs = moduli.iter().flatten().cloned().collect_vec();
423        let xs = values.iter().flatten().cloned().collect_vec();
424        assert_eq!(xs.len(), qs.len(), "unequal number of values and moduli");
425        let mut wires = self.encode_many(&xs, &qs, channel)?;
426        let buns = moduli
427            .iter()
428            .map(|qs| {
429                let ws = wires.drain(0..qs.len()).collect_vec();
430                Bundle::new(ws)
431            })
432            .collect_vec();
433        Ok(buns)
434    }
435
436    /// Receive many input bundles.
437    fn receive_many_bundles(
438        &mut self,
439        moduli: &[Vec<u16>],
440        channel: &mut Channel,
441    ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
442        let qs = moduli.iter().flatten().cloned().collect_vec();
443        let mut wires = self.receive_many(&qs, channel)?;
444        let buns = moduli
445            .iter()
446            .map(|qs| {
447                let ws = wires.drain(0..qs.len()).collect_vec();
448                Bundle::new(ws)
449            })
450            .collect_vec();
451        Ok(buns)
452    }
453
454    /// Creates a bundle of constant wires using moduli `ps`.
455    fn constant_bundle(
456        &mut self,
457        xs: &[u16],
458        ps: &[u16],
459        channel: &mut Channel,
460    ) -> swanky_error::Result<Bundle<Self::Item>> {
461        xs.iter()
462            .zip(ps.iter())
463            .map(|(&x, &p)| self.constant(x, p, channel))
464            .collect::<swanky_error::Result<_>>()
465            .map(Bundle)
466    }
467
468    /// Output the wires that make up a bundle.
469    fn output_bundle(
470        &mut self,
471        x: &Bundle<Self::Item>,
472        channel: &mut Channel,
473    ) -> swanky_error::Result<Option<Vec<u16>>> {
474        let ws = x.wires();
475        let mut outputs = Vec::with_capacity(ws.len());
476        for w in ws.iter() {
477            outputs.push(self.output(w, channel)?);
478        }
479        Ok(outputs.into_iter().collect())
480    }
481
482    /// Output a slice of bundles.
483    fn output_bundles(
484        &mut self,
485        xs: &[Bundle<Self::Item>],
486        channel: &mut Channel,
487    ) -> swanky_error::Result<Option<Vec<Vec<u16>>>> {
488        let mut zs = Vec::with_capacity(xs.len());
489        for x in xs.iter() {
490            let z = self.output_bundle(x, channel)?;
491            zs.push(z);
492        }
493        Ok(zs.into_iter().collect())
494    }
495
496    ////////////////////////////////////////////////////////////////////////////////
497    // gadgets which are neither CRT or binary
498
499    /// Shift residues, replacing them with zeros in the modulus of the least signifigant
500    /// residue. Maintains the length of the input.
501    fn shift(
502        &mut self,
503        x: &Bundle<Self::Item>,
504        n: usize,
505        channel: &mut Channel,
506    ) -> swanky_error::Result<Bundle<Self::Item>> {
507        let mut ws = x.wires().to_vec();
508        let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
509        for _ in 0..n {
510            ws.pop();
511            ws.insert(0, zero.clone());
512        }
513        Ok(Bundle(ws))
514    }
515
516    /// Shift residues, replacing them with zeros in the modulus of the least signifigant
517    /// residue. Output is extended with n elements.
518    fn shift_extend(
519        &mut self,
520        x: &Bundle<Self::Item>,
521        n: usize,
522        channel: &mut Channel,
523    ) -> swanky_error::Result<Bundle<Self::Item>> {
524        let mut ws = x.wires().to_vec();
525        let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
526        for _ in 0..n {
527            ws.insert(0, zero.clone());
528        }
529        Ok(Bundle(ws))
530    }
531}