fancy_garbling/fancy/
crt.rs

1//! Module containing `CrtGadgets`, which are the CRT-based gadgets for `Fancy`.
2
3use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5    FancyArithmetic, FancyBinary,
6    fancy::bundle::{Bundle, BundleGadgets},
7    util,
8};
9use itertools::Itertools;
10use std::ops::Deref;
11use swanky_channel::Channel;
12
13/// Bundle which is explicitly CRT-representation.
14#[derive(Clone)]
15pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18    /// Create a new CRT bundle from a vector of wires.
19    pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20        CrtBundle(Bundle::new(ws))
21    }
22
23    /// Extract the underlying bundle from this CRT bundle.
24    pub fn extract(self) -> Bundle<W> {
25        self.0
26    }
27
28    /// Return the product of all the wires' moduli.
29    pub fn composite_modulus(&self) -> u128 {
30        util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31    }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35    type Target = Bundle<W>;
36
37    fn deref(&self) -> &Bundle<W> {
38        &self.0
39    }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43    fn from(b: Bundle<W>) -> CrtBundle<W> {
44        CrtBundle(b)
45    }
46}
47
48impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50/// Extension trait for `Fancy` providing advanced CRT gadgets based on bundles of wires.
51pub trait CrtGadgets:
52    FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54    /// Creates a bundle of constant wires for the CRT representation of `x` under
55    /// composite modulus `q`.
56    fn crt_constant_bundle(
57        &mut self,
58        x: u128,
59        q: u128,
60        channel: &mut Channel,
61    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
62        let ps = util::factor(q);
63        let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
64        self.constant_bundle(&xs, &ps, channel).map(CrtBundle)
65    }
66
67    /// Output a CRT bundle and interpret it mod Q.
68    fn crt_output(
69        &mut self,
70        x: &CrtBundle<Self::Item>,
71        channel: &mut Channel,
72    ) -> swanky_error::Result<Option<u128>> {
73        let q = x.composite_modulus();
74        Ok(self
75            .output_bundle(x, channel)?
76            .map(|xs| util::crt_inv_factor(&xs, q)))
77    }
78
79    /// Output a slice of CRT bundles and interpret the outputs mod Q.
80    fn crt_outputs(
81        &mut self,
82        xs: &[CrtBundle<Self::Item>],
83        channel: &mut Channel,
84    ) -> swanky_error::Result<Option<Vec<u128>>> {
85        let mut zs = Vec::with_capacity(xs.len());
86        for x in xs.iter() {
87            let z = self.crt_output(x, channel)?;
88            zs.push(z);
89        }
90        Ok(zs.into_iter().collect())
91    }
92
93    ////////////////////////////////////////////////////////////////////////////////
94    // High-level computations dealing with bundles.
95
96    /// Add two CRT bundles.
97    fn crt_add(
98        &mut self,
99        x: &CrtBundle<Self::Item>,
100        y: &CrtBundle<Self::Item>,
101    ) -> CrtBundle<Self::Item> {
102        CrtBundle(self.add_bundles(x, y))
103    }
104
105    /// Subtract two CRT bundles.
106    fn crt_sub(
107        &mut self,
108        x: &CrtBundle<Self::Item>,
109        y: &CrtBundle<Self::Item>,
110    ) -> CrtBundle<Self::Item> {
111        CrtBundle(self.sub_bundles(x, y))
112    }
113
114    /// Multiplies each wire in `x` by the corresponding residue of `c`.
115    fn crt_cmul(&mut self, x: &CrtBundle<Self::Item>, c: u128) -> CrtBundle<Self::Item> {
116        let cs = util::crt(c, &x.moduli());
117        CrtBundle::new(
118            x.wires()
119                .iter()
120                .zip(cs)
121                .map(|(x, c)| self.cmul(x, c))
122                .collect::<Vec<Self::Item>>(),
123        )
124    }
125
126    /// Multiply `x` with `y`.
127    fn crt_mul(
128        &mut self,
129        x: &CrtBundle<Self::Item>,
130        y: &CrtBundle<Self::Item>,
131        channel: &mut Channel,
132    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
133        self.mul_bundles(x, y, channel).map(CrtBundle)
134    }
135
136    /// Exponentiate `x` by the constant `c`.
137    fn crt_cexp(
138        &mut self,
139        x: &CrtBundle<Self::Item>,
140        c: u16,
141        channel: &mut Channel,
142    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
143        x.wires()
144            .iter()
145            .map(|x| {
146                let p = x.modulus();
147                let tab = (0..p)
148                    .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
149                    .collect_vec();
150                self.proj(x, p, Some(tab), channel)
151            })
152            .collect::<swanky_error::Result<_>>()
153            .map(CrtBundle::new)
154    }
155
156    /// Compute the remainder with respect to modulus `p`.
157    ///
158    /// # Panics
159    /// Panics if `p` is not a modulus contained in `x`.
160    fn crt_rem(
161        &mut self,
162        x: &CrtBundle<Self::Item>,
163        p: u16,
164        channel: &mut Channel,
165    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
166        let i = x.moduli().iter().position(|&q| p == q);
167        assert!(i.is_some(), "`p` is not a modulus in the `x` bundle");
168        let i = i.unwrap();
169        let w = &x.wires()[i];
170        x.moduli()
171            .iter()
172            .map(|&q| self.mod_change(w, q, channel))
173            .collect::<swanky_error::Result<_>>()
174            .map(CrtBundle::new)
175    }
176
177    ////////////////////////////////////////////////////////////////////////////////
178    // Fancy functions based on Mike's fractional mixed radix trick.
179
180    /// Helper function for advanced gadgets, returns the MSB of the fractional part of
181    /// `X/M` where `M=product(ms)`.
182    fn crt_fractional_mixed_radix(
183        &mut self,
184        bun: &CrtBundle<Self::Item>,
185        ms: &[u16],
186        channel: &mut Channel,
187    ) -> swanky_error::Result<Self::Item> {
188        let ndigits = ms.len();
189
190        let q = util::product(&bun.moduli());
191        let M = util::product(ms);
192
193        let mut ds = Vec::new();
194
195        for wire in bun.wires().iter() {
196            let p = wire.modulus();
197
198            let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
199
200            for x in 0..p {
201                let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
202                let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
203                let digits = util::as_mixed_radix(y, ms);
204                for i in 0..ndigits {
205                    tabs[i].push(digits[i]);
206                }
207            }
208
209            let new_ds = tabs
210                .into_iter()
211                .enumerate()
212                .map(|(i, tt)| self.proj(wire, ms[i], Some(tt), channel))
213                .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
214
215            ds.push(Bundle::new(new_ds));
216        }
217
218        self.mixed_radix_addition_msb_only(&ds, channel)
219    }
220
221    /// Compute `max(x,0)`.
222    ///
223    /// Optional output moduli.
224    fn crt_relu(
225        &mut self,
226        x: &CrtBundle<Self::Item>,
227        accuracy: &str,
228        output_moduli: Option<&[u16]>,
229        channel: &mut Channel,
230    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
231        let factors_of_m = &get_ms(x, accuracy);
232        let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
233
234        // project the MSB to 0/1, whether or not it is less than p/2
235        let p = *factors_of_m.last().unwrap();
236        let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
237        let mask = self.proj(&res, 2, Some(mask_tt), channel)?;
238
239        // use the mask to either output x or 0
240        output_moduli
241            .map(|ps| x.with_moduli(ps))
242            .as_ref()
243            .unwrap_or(x)
244            .wires()
245            .iter()
246            .map(|x| self.mul(x, &mask, channel))
247            .collect::<swanky_error::Result<_>>()
248            .map(CrtBundle::new)
249    }
250
251    /// Return 0 if `x` is positive and 1 if `x` is negative.
252    fn crt_sign(
253        &mut self,
254        x: &CrtBundle<Self::Item>,
255        accuracy: &str,
256        channel: &mut Channel,
257    ) -> swanky_error::Result<Self::Item> {
258        let factors_of_m = &get_ms(x, accuracy);
259        let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
260        let p = *factors_of_m.last().unwrap();
261        let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
262        self.proj(&res, 2, Some(tt), channel)
263    }
264
265    /// Return `if x >= 0 then 1 else -1`, where `-1` is interpreted as `Q-1`.
266    ///
267    /// If provided, will produce a bundle under `output_moduli` instead of `x.moduli()`
268    fn crt_sgn(
269        &mut self,
270        x: &CrtBundle<Self::Item>,
271        accuracy: &str,
272        output_moduli: Option<&[u16]>,
273        channel: &mut Channel,
274    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
275        let sign = self.crt_sign(x, accuracy, channel)?;
276        output_moduli
277            .unwrap_or(&x.moduli())
278            .iter()
279            .map(|&p| {
280                let tt = vec![1, p - 1];
281                self.proj(&sign, p, Some(tt), channel)
282            })
283            .collect::<swanky_error::Result<_>>()
284            .map(CrtBundle::new)
285    }
286
287    /// Returns 1 if `x < y`.
288    fn crt_lt(
289        &mut self,
290        x: &CrtBundle<Self::Item>,
291        y: &CrtBundle<Self::Item>,
292        accuracy: &str,
293        channel: &mut Channel,
294    ) -> swanky_error::Result<Self::Item> {
295        let z = self.crt_sub(x, y);
296        self.crt_sign(&z, accuracy, channel)
297    }
298
299    /// Returns 1 if `x >= y`.
300    fn crt_geq(
301        &mut self,
302        x: &CrtBundle<Self::Item>,
303        y: &CrtBundle<Self::Item>,
304        accuracy: &str,
305        channel: &mut Channel,
306    ) -> swanky_error::Result<Self::Item> {
307        let z = self.crt_lt(x, y, accuracy, channel)?;
308        Ok(self.negate(&z))
309    }
310
311    /// Compute the maximum bundle in `xs`.
312    ///
313    /// # Panics
314    /// Panics if `xs` is empty.
315    fn crt_max(
316        &mut self,
317        xs: &[CrtBundle<Self::Item>],
318        accuracy: &str,
319        channel: &mut Channel,
320    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
321        assert!(!xs.is_empty(), "`xs` cannot be empty");
322        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
323            let pos = self.crt_lt(&x, y, accuracy, channel)?;
324            let neg = self.negate(&pos);
325            Ok(CrtBundle::new(
326                x.wires()
327                    .iter()
328                    .zip(y.wires().iter())
329                    .map(|(x, y)| {
330                        let xp = self.mul(x, &neg, channel)?;
331                        let yp = self.mul(y, &pos, channel)?;
332                        Ok(self.add(&xp, &yp))
333                    })
334                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
335            ))
336        })
337    }
338
339    /// Convert the xs bundle to PMR representation. Useful for extracting out of CRT.
340    fn crt_to_pmr(
341        &mut self,
342        xs: &CrtBundle<Self::Item>,
343        channel: &mut Channel,
344    ) -> swanky_error::Result<Bundle<Self::Item>> {
345        let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
346            let pq = p as u32 + q as u32 - 1;
347            let mut tab = Vec::with_capacity(pq as usize);
348            for z in 0..pq {
349                let mut x = 0;
350                let mut y = 0;
351                'outer: for i in 0..p as u32 {
352                    for j in 0..q as u32 {
353                        if (i + pq - j) % pq == z {
354                            x = i;
355                            y = j;
356                            break 'outer;
357                        }
358                    }
359                }
360                debug_assert_eq!((x + pq - y) % pq, z);
361                tab.push(
362                    (((x * q as u32 * util::inv(q as i128, p as i128) as u32
363                        + y * p as u32 * util::inv(p as i128, q as i128) as u32)
364                        / p as u32)
365                        % q as u32) as u16,
366                );
367            }
368            tab
369        };
370
371        let mut gadget = |x: &Self::Item, y: &Self::Item| -> swanky_error::Result<Self::Item> {
372            let p = x.modulus();
373            let q = y.modulus();
374            let x_ = self.mod_change(x, p + q - 1, channel)?;
375            let y_ = self.mod_change(y, p + q - 1, channel)?;
376            let z = self.sub(&x_, &y_);
377            self.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
378        };
379
380        let n = xs.size();
381        let mut x = vec![vec![None; n + 1]; n + 1];
382
383        for j in 0..n {
384            x[0][j + 1] = Some(xs.wires()[j].clone());
385        }
386
387        for i in 1..=n {
388            for j in i + 1..=n {
389                let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
390                x[i][j] = Some(z);
391            }
392        }
393
394        let mut zwires = Vec::with_capacity(n);
395        for i in 0..n {
396            zwires.push(x[i][i + 1].take().unwrap());
397        }
398        Ok(Bundle::new(zwires))
399    }
400
401    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
402    /// it to work, there must be an extra modulus in the CRT that is not necessary to
403    /// represent the values. This ensures that if x < y, the most significant PMR digit
404    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
405    /// before using this gadget.
406    fn pmr_lt(
407        &mut self,
408        x: &CrtBundle<Self::Item>,
409        y: &CrtBundle<Self::Item>,
410        channel: &mut Channel,
411    ) -> swanky_error::Result<Self::Item> {
412        let z = self.crt_sub(x, y);
413        let mut pmr = self.crt_to_pmr(&z, channel)?;
414        let w = pmr.pop().unwrap();
415        let mut tab = vec![1; w.modulus() as usize];
416        tab[0] = 0;
417        self.proj(&w, 2, Some(tab), channel)
418    }
419
420    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
421    /// it to work, there must be an extra modulus in the CRT that is not necessary to
422    /// represent the values. This ensures that if x < y, the most significant PMR digit
423    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
424    /// before using this gadget.
425    fn pmr_geq(
426        &mut self,
427        x: &CrtBundle<Self::Item>,
428        y: &CrtBundle<Self::Item>,
429        channel: &mut Channel,
430    ) -> swanky_error::Result<Self::Item> {
431        let z = self.pmr_lt(x, y, channel)?;
432        Ok(self.negate(&z))
433    }
434
435    /// Generic, and expensive, CRT-based addition for two ciphertexts. Uses PMR
436    /// comparison repeatedly. Requires an extra unused prime in both inputs.
437    ///
438    /// # Panics
439    /// Panics if `x` and `y` do not have equal moduli.
440    fn crt_div(
441        &mut self,
442        x: &CrtBundle<Self::Item>,
443        y: &CrtBundle<Self::Item>,
444        channel: &mut Channel,
445    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
446        assert_eq!(x.moduli(), y.moduli());
447
448        let q = x.composite_modulus();
449
450        // Compute l based on the assumption that the last prime is unused.
451        let nprimes = x.moduli().len();
452        let qs_ = &x.moduli()[..nprimes - 1];
453        let q_ = util::product(qs_);
454        let l = 128 - q_.leading_zeros();
455
456        let mut quotient = self.crt_constant_bundle(0, q, channel)?;
457        let mut a = x.clone();
458
459        let one = self.crt_constant_bundle(1, q, channel)?;
460        for i in 0..l {
461            let b = 2u128.pow(l - i - 1);
462            let mut pb = q_ / b;
463            if q_.is_multiple_of(b) {
464                pb -= 1;
465            }
466
467            let tmp = self.crt_cmul(y, b);
468            let c1 = self.pmr_geq(&a, &tmp, channel)?;
469
470            let pb_crt = self.crt_constant_bundle(pb, q, channel)?;
471            let c2 = self.pmr_geq(&pb_crt, y, channel)?;
472
473            let c = self.and(&c1, &c2, channel)?;
474
475            let c_ws = one
476                .iter()
477                .map(|w| self.mul(w, &c, channel))
478                .collect::<Result<Vec<_>, _>>()?;
479            let c_crt = CrtBundle::new(c_ws);
480
481            let b_if = self.crt_cmul(&c_crt, b);
482            quotient = self.crt_add(&quotient, &b_if);
483
484            let tmp_if = self.crt_mul(&c_crt, &tmp, channel)?;
485            a = self.crt_sub(&a, &tmp_if);
486        }
487
488        Ok(quotient)
489    }
490}
491
492/// Compute the `ms` needed for the number of CRT primes in `x`, with accuracy
493/// `accuracy`.
494///
495/// Supported accuracy: ["100%", "99.9%", "99%"]
496fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
497    match accuracy {
498        "100%" => match x.moduli().len() {
499            3 => vec![2; 5],
500            4 => vec![3, 26],
501            5 => vec![3, 4, 54],
502            6 => vec![5, 5, 5, 60],
503            7 => vec![5, 6, 6, 7, 86],
504            8 => vec![5, 7, 8, 8, 9, 98],
505            9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
506            10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
507            11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
508            n => panic!("unknown exact Ms for {} primes!", n),
509        },
510        "99.999%" => match x.moduli().len() {
511            8 => vec![5, 5, 6, 7, 102],
512            9 => vec![5, 5, 6, 7, 114],
513            10 => vec![5, 6, 6, 7, 102],
514            11 => vec![5, 5, 6, 7, 130],
515            n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
516        },
517        "99.99%" => match x.moduli().len() {
518            6 => vec![5, 5, 5, 42],
519            7 => vec![4, 5, 6, 88],
520            8 => vec![4, 5, 7, 78],
521            9 => vec![5, 5, 6, 84],
522            10 => vec![4, 5, 6, 112],
523            11 => vec![7, 11, 174],
524            n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
525        },
526        "99.9%" => match x.moduli().len() {
527            5 => vec![3, 5, 30],
528            6 => vec![4, 5, 48],
529            7 => vec![4, 5, 60],
530            8 => vec![3, 5, 78],
531            9 => vec![9, 140],
532            10 => vec![7, 190],
533            n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
534        },
535        "99%" => match x.moduli().len() {
536            4 => vec![3, 18],
537            5 => vec![3, 36],
538            6 => vec![3, 40],
539            7 => vec![3, 40],
540            8 => vec![126],
541            9 => vec![138],
542            10 => vec![140],
543            n => panic!("unknown 99% accurate Ms for {} primes!", n),
544        },
545        _ => panic!("get_ms: unsupported accuracy {}", accuracy),
546    }
547}