fancy_garbling/fancy/
bundle.rs

1use crate::{
2    FancyArithmetic, FancyBinary,
3    fancy::{Fancy, HasModulus},
4};
5use itertools::Itertools;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::ops::Index;
9use swanky_channel::Channel;
10
11/// A collection of wires, useful for the garbled gadgets defined by `BundleGadgets`.
12#[derive(Clone)]
13#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
14pub struct Bundle<W>(Vec<W>);
15
16impl<W: Clone + HasModulus> Bundle<W> {
17    /// Create a new bundle from some wires.
18    pub fn new(ws: Vec<W>) -> Bundle<W> {
19        Bundle(ws)
20    }
21
22    /// Return the moduli of all the wires in the bundle.
23    pub fn moduli(&self) -> Vec<u16> {
24        self.0.iter().map(HasModulus::modulus).collect()
25    }
26
27    /// Extract the wires from this bundle.
28    pub fn wires(&self) -> &[W] {
29        &self.0
30    }
31
32    /// Get the number of wires in this bundle.
33    pub fn size(&self) -> usize {
34        self.0.len()
35    }
36
37    /// Whether this bundle only contains residues in mod 2.
38    pub fn is_binary(&self) -> bool {
39        self.moduli().iter().all(|m| *m == 2)
40    }
41
42    /// Returns a new bundle only containing wires with matching moduli.
43    pub fn with_moduli(&self, moduli: &[u16]) -> Bundle<W> {
44        let old_ws = self.wires();
45        let mut new_ws = Vec::with_capacity(moduli.len());
46        for &p in moduli {
47            if let Some(w) = old_ws.iter().find(|&x| x.modulus() == p) {
48                new_ws.push(w.clone());
49            } else {
50                panic!("Bundle::with_moduli: no {} modulus in bundle", p);
51            }
52        }
53        Bundle(new_ws)
54    }
55
56    /// Pad the Bundle with val, n times.
57    pub fn pad(&mut self, val: &W, n: usize) {
58        for _ in 0..n {
59            self.0.push(val.clone());
60        }
61    }
62
63    /// Extract a wire from the Bundle, removing it and returning it.
64    pub fn extract(&mut self, wire_index: usize) -> W {
65        self.0.remove(wire_index)
66    }
67
68    /// Insert a wire from the Bundle
69    pub fn insert(&mut self, wire_index: usize, val: W) {
70        self.0.insert(wire_index, val)
71    }
72
73    /// push a wire onto the Bundle.
74    pub fn push(&mut self, val: W) {
75        self.0.push(val);
76    }
77
78    /// Pop a wire from the Bundle.
79    pub fn pop(&mut self) -> Option<W> {
80        self.0.pop()
81    }
82
83    /// Access the underlying iterator
84    pub fn iter(&self) -> std::slice::Iter<'_, W> {
85        self.0.iter()
86    }
87
88    /// Reverse the wires
89    pub fn reverse(&mut self) {
90        self.0.reverse();
91    }
92}
93
94impl<W: Clone + HasModulus> Index<usize> for Bundle<W> {
95    type Output = W;
96
97    fn index(&self, idx: usize) -> &Self::Output {
98        self.0.index(idx)
99    }
100}
101
102impl<F: Fancy> BundleGadgets for F {}
103impl<F: FancyArithmetic> ArithmeticBundleGadgets for F {}
104impl<F: FancyBinary> BinaryBundleGadgets for F {}
105
106/// Arithmetic operations on wire bundles, extending the capability of `FancyArithmetic` operating
107/// on individual wires.
108pub trait ArithmeticBundleGadgets: FancyArithmetic {
109    /// Add two wire bundles pairwise, zipping addition.
110    ///
111    /// In CRT this is plain addition. In binary this is xor.
112    ///
113    /// # Panics
114    /// Panics if `x` and `y` are not of the same length.
115    fn add_bundles(
116        &mut self,
117        x: &Bundle<Self::Item>,
118        y: &Bundle<Self::Item>,
119    ) -> Bundle<Self::Item> {
120        assert_eq!(
121            x.wires().len(),
122            y.wires().len(),
123            "`x` and `y` must be the same length"
124        );
125        Bundle::new(
126            x.wires()
127                .iter()
128                .zip(y.wires().iter())
129                .map(|(x, y)| self.add(x, y))
130                .collect::<Vec<Self::Item>>(),
131        )
132    }
133
134    /// Subtract two wire bundles, residue by residue.
135    ///
136    /// In CRT this is plain subtraction. In binary this is `xor`.
137    ///
138    /// # Panics
139    /// Panics if `x` and `y` are not of the same length.
140    fn sub_bundles(
141        &mut self,
142        x: &Bundle<Self::Item>,
143        y: &Bundle<Self::Item>,
144    ) -> Bundle<Self::Item> {
145        assert_eq!(
146            x.wires().len(),
147            y.wires().len(),
148            "`x` and `y` must be the same length"
149        );
150        Bundle::new(
151            x.wires()
152                .iter()
153                .zip(y.wires().iter())
154                .map(|(x, y)| self.sub(x, y))
155                .collect::<Vec<Self::Item>>(),
156        )
157    }
158
159    /// Multiply each wire in `x` with each wire in `y`, pairwise.
160    ///
161    /// In CRT this is plain multiplication. In binary this is `and`.
162    fn mul_bundles(
163        &mut self,
164        x: &Bundle<Self::Item>,
165        y: &Bundle<Self::Item>,
166        channel: &mut Channel,
167    ) -> swanky_error::Result<Bundle<Self::Item>> {
168        x.wires()
169            .iter()
170            .zip(y.wires().iter())
171            .map(|(x, y)| self.mul(x, y, channel))
172            .collect::<swanky_error::Result<_>>()
173            .map(Bundle::new)
174    }
175
176    /// Mixed radix addition.
177    ///
178    /// # Panics
179    /// Panics if `xs` is empty, or the moduli in `xs` are not all equal.
180    fn mixed_radix_addition(
181        &mut self,
182        xs: &[Bundle<Self::Item>],
183        channel: &mut Channel,
184    ) -> swanky_error::Result<Bundle<Self::Item>> {
185        assert!(!xs.is_empty(), "`xs` cannot be empty");
186        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
187
188        let nargs = xs.len();
189        let n = xs[0].wires().len();
190
191        let mut digit_carry = None;
192        let mut carry_carry = None;
193        let mut max_carry = 0;
194
195        let mut res = Vec::with_capacity(n);
196
197        for i in 0..n {
198            // all the ith digits, in one vec
199            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
200
201            // compute the digit -- easy
202            let digit_sum = self.add_many(&ds);
203            let digit = digit_carry.map_or(digit_sum.clone(), |d| self.add(&digit_sum, &d));
204
205            if i < n - 1 {
206                // compute the carries
207                let q = xs[0].wires()[i].modulus();
208                // max_carry currently contains the max carry from the previous iteration
209                let max_val = nargs as u16 * (q - 1) + max_carry;
210                // now it is the max carry of this iteration
211                max_carry = max_val / q;
212
213                let modded_ds = ds
214                    .iter()
215                    .map(|d| self.mod_change(d, max_val + 1, channel))
216                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
217
218                let carry_sum = self.add_many(&modded_ds);
219                // add in the carry from the previous iteration
220                let carry = carry_carry.map_or(carry_sum.clone(), |c| self.add(&carry_sum, &c));
221
222                // carry now contains the carry information, we just have to project it to
223                // the correct moduli for the next iteration
224                let next_mod = xs[0].wires()[i + 1].modulus();
225                let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
226                digit_carry = Some(self.proj(&carry, next_mod, Some(tt), channel)?);
227
228                let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
229
230                if i < n - 2 {
231                    if max_carry < next_mod {
232                        carry_carry = Some(self.mod_change(
233                            digit_carry.as_ref().unwrap(),
234                            next_max_val + 1,
235                            channel,
236                        )?);
237                    } else {
238                        let tt = (0..=max_val).map(|i| i / q).collect_vec();
239                        carry_carry =
240                            Some(self.proj(&carry, next_max_val + 1, Some(tt), channel)?);
241                    }
242                } else {
243                    // next digit is MSB so we dont need carry_carry
244                    carry_carry = None;
245                }
246            } else {
247                digit_carry = None;
248                carry_carry = None;
249            }
250            res.push(digit);
251        }
252        Ok(Bundle(res))
253    }
254
255    /// Mixed radix addition only returning the MSB.
256    ///
257    /// # Panics
258    /// Panics if `xs` is empty, or the moduli in `xs` are not all equal.
259    fn mixed_radix_addition_msb_only(
260        &mut self,
261        xs: &[Bundle<Self::Item>],
262        channel: &mut Channel,
263    ) -> swanky_error::Result<Self::Item> {
264        assert!(!xs.is_empty(), "`xs` cannot be empty");
265        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
266
267        let nargs = xs.len();
268        let n = xs[0].wires().len();
269
270        let mut opt_carry = None;
271        let mut max_carry = 0;
272
273        for i in 0..n - 1 {
274            // all the ith digits, in one vec
275            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
276            // compute the carry
277            let q = xs[0].moduli()[i];
278            // max_carry currently contains the max carry from the previous iteration
279            let max_val = nargs as u16 * (q - 1) + max_carry;
280            // now it is the max carry of this iteration
281            max_carry = max_val / q;
282
283            // mod change the digits to the max sum possible plus the max carry of the
284            // previous iteration
285            let modded_ds = ds
286                .iter()
287                .map(|d| self.mod_change(d, max_val + 1, channel))
288                .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
289            // add them up
290            let sum = self.add_many(&modded_ds);
291            // add in the carry
292            let sum_with_carry = opt_carry
293                .as_ref()
294                .map_or(sum.clone(), |c| self.add(&sum, c));
295
296            // carry now contains the carry information, we just have to project it to
297            // the correct moduli for the next iteration. It will either be used to
298            // compute the next carry, if i < n-2, or it will be used to compute the
299            // output MSB, in which case it should be the modulus of the SB
300            let next_mod = if i < n - 2 {
301                nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
302            } else {
303                xs[0].moduli()[i + 1] // we will be adding the carry to the MSB
304            };
305
306            let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
307            opt_carry = Some(self.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
308        }
309
310        // compute the msb
311        let ds = xs.iter().map(|x| x.wires()[n - 1].clone()).collect_vec();
312        let digit_sum = self.add_many(&ds);
313        Ok(opt_carry
314            .as_ref()
315            .map_or(digit_sum.clone(), |d| self.add(&digit_sum, d)))
316    }
317
318    /// If b=0 then return 0, else return x.
319    fn mask(
320        &mut self,
321        b: &Self::Item,
322        x: &Bundle<Self::Item>,
323        channel: &mut Channel,
324    ) -> swanky_error::Result<Bundle<Self::Item>> {
325        x.wires()
326            .iter()
327            .map(|xwire| self.mul(xwire, b, channel))
328            .collect::<swanky_error::Result<_>>()
329            .map(Bundle)
330    }
331
332    /// Compute `x == y`. Returns a wire encoding the result mod 2.
333    ///
334    /// # Panics
335    /// Panics if `x` and `y` do not have equal moduli.
336    fn eq_bundles(
337        &mut self,
338        x: &Bundle<Self::Item>,
339        y: &Bundle<Self::Item>,
340        channel: &mut Channel,
341    ) -> swanky_error::Result<Self::Item> {
342        assert_eq!(x.moduli(), y.moduli());
343
344        let wlen = x.wires().len() as u16;
345        let zs = x
346            .wires()
347            .iter()
348            .zip_eq(y.wires().iter())
349            .map(|(x, y)| {
350                // compute (x-y == 0) for each residue
351                let z = self.sub(x, y);
352                let mut eq_zero_tab = vec![0; x.modulus() as usize];
353                eq_zero_tab[0] = 1;
354                self.proj(&z, wlen + 1, Some(eq_zero_tab), channel)
355            })
356            .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
357        // add up the results, and output whether they equal zero or not, mod 2
358        let z = self.add_many(&zs);
359        let b = zs.len();
360        let mut tab = vec![0; b + 1];
361        tab[b] = 1;
362        self.proj(&z, 2, Some(tab), channel)
363    }
364}
365
366/// Binary operations on wire bundles, extending the capability of `FancyBinary` operating
367/// on individual wires.
368pub trait BinaryBundleGadgets: FancyBinary {
369    /// If b=0 then return x, else return y.
370    fn multiplex(
371        &mut self,
372        b: &Self::Item,
373        x: &Bundle<Self::Item>,
374        y: &Bundle<Self::Item>,
375        channel: &mut Channel,
376    ) -> swanky_error::Result<Bundle<Self::Item>> {
377        x.wires()
378            .iter()
379            .zip(y.wires().iter())
380            .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
381            .collect::<swanky_error::Result<_>>()
382            .map(Bundle)
383    }
384}
385
386/// Extension trait for Fancy which provides Bundle constructions which are not
387/// necessarily CRT nor binary-based.
388pub trait BundleGadgets: Fancy {
389    /// Creates a bundle of constant wires using moduli `ps`.
390    fn constant_bundle(
391        &mut self,
392        xs: &[u16],
393        ps: &[u16],
394        channel: &mut Channel,
395    ) -> swanky_error::Result<Bundle<Self::Item>> {
396        xs.iter()
397            .zip(ps.iter())
398            .map(|(&x, &p)| self.constant(x, p, channel))
399            .collect::<swanky_error::Result<_>>()
400            .map(Bundle)
401    }
402
403    /// Output the wires that make up a bundle.
404    fn output_bundle(
405        &mut self,
406        x: &Bundle<Self::Item>,
407        channel: &mut Channel,
408    ) -> swanky_error::Result<Option<Vec<u16>>> {
409        let ws = x.wires();
410        let mut outputs = Vec::with_capacity(ws.len());
411        for w in ws.iter() {
412            outputs.push(self.output(w, channel)?);
413        }
414        Ok(outputs.into_iter().collect())
415    }
416
417    /// Output a slice of bundles.
418    fn output_bundles(
419        &mut self,
420        xs: &[Bundle<Self::Item>],
421        channel: &mut Channel,
422    ) -> swanky_error::Result<Option<Vec<Vec<u16>>>> {
423        let mut zs = Vec::with_capacity(xs.len());
424        for x in xs.iter() {
425            let z = self.output_bundle(x, channel)?;
426            zs.push(z);
427        }
428        Ok(zs.into_iter().collect())
429    }
430
431    ////////////////////////////////////////////////////////////////////////////////
432    // gadgets which are neither CRT or binary
433
434    /// Shift residues, replacing them with zeros in the modulus of the least signifigant
435    /// residue. Maintains the length of the input.
436    fn shift(
437        &mut self,
438        x: &Bundle<Self::Item>,
439        n: usize,
440        channel: &mut Channel,
441    ) -> swanky_error::Result<Bundle<Self::Item>> {
442        let mut ws = x.wires().to_vec();
443        let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
444        for _ in 0..n {
445            ws.pop();
446            ws.insert(0, zero.clone());
447        }
448        Ok(Bundle(ws))
449    }
450
451    /// Shift residues, replacing them with zeros in the modulus of the least signifigant
452    /// residue. Output is extended with n elements.
453    fn shift_extend(
454        &mut self,
455        x: &Bundle<Self::Item>,
456        n: usize,
457        channel: &mut Channel,
458    ) -> swanky_error::Result<Bundle<Self::Item>> {
459        let mut ws = x.wires().to_vec();
460        let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
461        for _ in 0..n {
462            ws.insert(0, zero.clone());
463        }
464        Ok(Bundle(ws))
465    }
466}