// fancy_garbling/fancy/crt.rs

1//! Module containing `CrtGadgets`, which are the CRT-based gadgets for `Fancy`.
2
3use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5    FancyArithmetic, FancyBinary,
6    errors::FancyError,
7    fancy::bundle::{Bundle, BundleGadgets},
8    util,
9};
10use itertools::Itertools;
11use std::ops::Deref;
12
13/// Bundle which is explicitly CRT-representation.
14#[derive(Clone)]
15pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18    /// Create a new CRT bundle from a vector of wires.
19    pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20        CrtBundle(Bundle::new(ws))
21    }
22
23    /// Extract the underlying bundle from this CRT bundle.
24    pub fn extract(self) -> Bundle<W> {
25        self.0
26    }
27
28    /// Return the product of all the wires' moduli.
29    pub fn composite_modulus(&self) -> u128 {
30        util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31    }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35    type Target = Bundle<W>;
36
37    fn deref(&self) -> &Bundle<W> {
38        &self.0
39    }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43    fn from(b: Bundle<W>) -> CrtBundle<W> {
44        CrtBundle(b)
45    }
46}
47
48impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50/// Extension trait for `Fancy` providing advanced CRT gadgets based on bundles of wires.
51pub trait CrtGadgets:
52    FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54    /// Creates a bundle of constant wires for the CRT representation of `x` under
55    /// composite modulus `q`.
56    fn crt_constant_bundle(
57        &mut self,
58        x: u128,
59        q: u128,
60    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
61        let ps = util::factor(q);
62        let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
63        self.constant_bundle(&xs, &ps).map(CrtBundle)
64    }
65
66    /// Output a CRT bundle and interpret it mod Q.
67    fn crt_output(&mut self, x: &CrtBundle<Self::Item>) -> Result<Option<u128>, Self::Error> {
68        let q = x.composite_modulus();
69        Ok(self
70            .output_bundle(x)?
71            .map(|xs| util::crt_inv_factor(&xs, q)))
72    }
73
74    /// Output a slice of CRT bundles and interpret the outputs mod Q.
75    fn crt_outputs(
76        &mut self,
77        xs: &[CrtBundle<Self::Item>],
78    ) -> Result<Option<Vec<u128>>, Self::Error> {
79        let mut zs = Vec::with_capacity(xs.len());
80        for x in xs.iter() {
81            let z = self.crt_output(x)?;
82            zs.push(z);
83        }
84        Ok(zs.into_iter().collect())
85    }
86
87    ////////////////////////////////////////////////////////////////////////////////
88    // High-level computations dealing with bundles.
89
90    /// Add two CRT bundles.
91    fn crt_add(
92        &mut self,
93        x: &CrtBundle<Self::Item>,
94        y: &CrtBundle<Self::Item>,
95    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
96        self.add_bundles(x, y).map(CrtBundle)
97    }
98
99    /// Subtract two CRT bundles.
100    fn crt_sub(
101        &mut self,
102        x: &CrtBundle<Self::Item>,
103        y: &CrtBundle<Self::Item>,
104    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
105        self.sub_bundles(x, y).map(CrtBundle)
106    }
107
108    /// Multiplies each wire in `x` by the corresponding residue of `c`.
109    fn crt_cmul(
110        &mut self,
111        x: &CrtBundle<Self::Item>,
112        c: u128,
113    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
114        let cs = util::crt(c, &x.moduli());
115        x.wires()
116            .iter()
117            .zip(cs.into_iter())
118            .map(|(x, c)| self.cmul(x, c))
119            .collect::<Result<Vec<Self::Item>, Self::Error>>()
120            .map(CrtBundle::new)
121    }
122
123    /// Multiply `x` with `y`.
124    fn crt_mul(
125        &mut self,
126        x: &CrtBundle<Self::Item>,
127        y: &CrtBundle<Self::Item>,
128    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
129        self.mul_bundles(x, y).map(CrtBundle)
130    }
131
132    /// Exponentiate `x` by the constant `c`.
133    fn crt_cexp(
134        &mut self,
135        x: &CrtBundle<Self::Item>,
136        c: u16,
137    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
138        x.wires()
139            .iter()
140            .map(|x| {
141                let p = x.modulus();
142                let tab = (0..p)
143                    .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
144                    .collect_vec();
145                self.proj(x, p, Some(tab))
146            })
147            .collect::<Result<Vec<Self::Item>, Self::Error>>()
148            .map(CrtBundle::new)
149    }
150
151    /// Compute the remainder with respect to modulus `p`.
152    fn crt_rem(
153        &mut self,
154        x: &CrtBundle<Self::Item>,
155        p: u16,
156    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
157        let i = x.moduli().iter().position(|&q| p == q).ok_or_else(|| {
158            Self::Error::from(FancyError::InvalidArg(
159                "p is not a modulus in this bundle!".to_string(),
160            ))
161        })?;
162        let w = &x.wires()[i];
163        x.moduli()
164            .iter()
165            .map(|&q| self.mod_change(w, q))
166            .collect::<Result<Vec<Self::Item>, Self::Error>>()
167            .map(CrtBundle::new)
168    }
169
170    ////////////////////////////////////////////////////////////////////////////////
171    // Fancy functions based on Mike's fractional mixed radix trick.
172
173    /// Helper function for advanced gadgets, returns the MSB of the fractional part of
174    /// `X/M` where `M=product(ms)`.
175    fn crt_fractional_mixed_radix(
176        &mut self,
177        bun: &CrtBundle<Self::Item>,
178        ms: &[u16],
179    ) -> Result<Self::Item, Self::Error> {
180        let ndigits = ms.len();
181
182        let q = util::product(&bun.moduli());
183        let M = util::product(ms);
184
185        let mut ds = Vec::new();
186
187        for wire in bun.wires().iter() {
188            let p = wire.modulus();
189
190            let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
191
192            for x in 0..p {
193                let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
194                let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
195                let digits = util::as_mixed_radix(y, ms);
196                for i in 0..ndigits {
197                    tabs[i].push(digits[i]);
198                }
199            }
200
201            let new_ds = tabs
202                .into_iter()
203                .enumerate()
204                .map(|(i, tt)| self.proj(wire, ms[i], Some(tt)))
205                .collect::<Result<Vec<Self::Item>, Self::Error>>()?;
206
207            ds.push(Bundle::new(new_ds));
208        }
209
210        self.mixed_radix_addition_msb_only(&ds)
211    }
212
213    /// Compute `max(x,0)`.
214    ///
215    /// Optional output moduli.
216    fn crt_relu(
217        &mut self,
218        x: &CrtBundle<Self::Item>,
219        accuracy: &str,
220        output_moduli: Option<&[u16]>,
221    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
222        let factors_of_m = &get_ms(x, accuracy);
223        let res = self.crt_fractional_mixed_radix(x, factors_of_m)?;
224
225        // project the MSB to 0/1, whether or not it is less than p/2
226        let p = *factors_of_m.last().unwrap();
227        let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
228        let mask = self.proj(&res, 2, Some(mask_tt))?;
229
230        // use the mask to either output x or 0
231        output_moduli
232            .map(|ps| x.with_moduli(ps))
233            .as_ref()
234            .unwrap_or(x)
235            .wires()
236            .iter()
237            .map(|x| self.mul(x, &mask))
238            .collect::<Result<Vec<Self::Item>, Self::Error>>()
239            .map(CrtBundle::new)
240    }
241
242    /// Return 0 if `x` is positive and 1 if `x` is negative.
243    fn crt_sign(
244        &mut self,
245        x: &CrtBundle<Self::Item>,
246        accuracy: &str,
247    ) -> Result<Self::Item, Self::Error> {
248        let factors_of_m = &get_ms(x, accuracy);
249        let res = self.crt_fractional_mixed_radix(x, factors_of_m)?;
250        let p = *factors_of_m.last().unwrap();
251        let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
252        self.proj(&res, 2, Some(tt))
253    }
254
255    /// Return `if x >= 0 then 1 else -1`, where `-1` is interpreted as `Q-1`.
256    ///
257    /// If provided, will produce a bundle under `output_moduli` instead of `x.moduli()`
258    fn crt_sgn(
259        &mut self,
260        x: &CrtBundle<Self::Item>,
261        accuracy: &str,
262        output_moduli: Option<&[u16]>,
263    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
264        let sign = self.crt_sign(x, accuracy)?;
265        output_moduli
266            .unwrap_or(&x.moduli())
267            .iter()
268            .map(|&p| {
269                let tt = vec![1, p - 1];
270                self.proj(&sign, p, Some(tt))
271            })
272            .collect::<Result<Vec<Self::Item>, Self::Error>>()
273            .map(CrtBundle::new)
274    }
275
276    /// Returns 1 if `x < y`.
277    fn crt_lt(
278        &mut self,
279        x: &CrtBundle<Self::Item>,
280        y: &CrtBundle<Self::Item>,
281        accuracy: &str,
282    ) -> Result<Self::Item, Self::Error> {
283        let z = self.crt_sub(x, y)?;
284        self.crt_sign(&z, accuracy)
285    }
286
287    /// Returns 1 if `x >= y`.
288    fn crt_geq(
289        &mut self,
290        x: &CrtBundle<Self::Item>,
291        y: &CrtBundle<Self::Item>,
292        accuracy: &str,
293    ) -> Result<Self::Item, Self::Error> {
294        let z = self.crt_lt(x, y, accuracy)?;
295        self.negate(&z)
296    }
297
298    /// Compute the maximum bundle in `xs`.
299    fn crt_max(
300        &mut self,
301        xs: &[CrtBundle<Self::Item>],
302        accuracy: &str,
303    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
304        if xs.is_empty() {
305            return Err(Self::Error::from(FancyError::InvalidArgNum {
306                got: xs.len(),
307                needed: 1,
308            }));
309        }
310        xs.iter().skip(1).fold(Ok(xs[0].clone()), |x, y| {
311            x.map(|x| {
312                let pos = self.crt_lt(&x, y, accuracy)?;
313                let neg = self.negate(&pos)?;
314                x.wires()
315                    .iter()
316                    .zip(y.wires().iter())
317                    .map(|(x, y)| {
318                        let xp = self.mul(x, &neg)?;
319                        let yp = self.mul(y, &pos)?;
320                        self.add(&xp, &yp)
321                    })
322                    .collect::<Result<Vec<Self::Item>, Self::Error>>()
323                    .map(CrtBundle::new)
324            })?
325        })
326    }
327
328    /// Convert the xs bundle to PMR representation. Useful for extracting out of CRT.
329    fn crt_to_pmr(
330        &mut self,
331        xs: &CrtBundle<Self::Item>,
332    ) -> Result<Bundle<Self::Item>, Self::Error> {
333        let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
334            let pq = p as u32 + q as u32 - 1;
335            let mut tab = Vec::with_capacity(pq as usize);
336            for z in 0..pq {
337                let mut x = 0;
338                let mut y = 0;
339                'outer: for i in 0..p as u32 {
340                    for j in 0..q as u32 {
341                        if (i + pq - j) % pq == z {
342                            x = i;
343                            y = j;
344                            break 'outer;
345                        }
346                    }
347                }
348                debug_assert_eq!((x + pq - y) % pq, z);
349                tab.push(
350                    (((x * q as u32 * util::inv(q as i128, p as i128) as u32
351                        + y * p as u32 * util::inv(p as i128, q as i128) as u32)
352                        / p as u32)
353                        % q as u32) as u16,
354                );
355            }
356            tab
357        };
358
359        let mut gadget = |x: &Self::Item, y: &Self::Item| -> Result<Self::Item, Self::Error> {
360            let p = x.modulus();
361            let q = y.modulus();
362            let x_ = self.mod_change(x, p + q - 1)?;
363            let y_ = self.mod_change(y, p + q - 1)?;
364            let z = self.sub(&x_, &y_)?;
365            self.proj(&z, q, Some(gadget_projection_tt(p, q)))
366        };
367
368        let n = xs.size();
369        let mut x = vec![vec![None; n + 1]; n + 1];
370
371        for j in 0..n {
372            x[0][j + 1] = Some(xs.wires()[j].clone());
373        }
374
375        for i in 1..=n {
376            for j in i + 1..=n {
377                let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
378                x[i][j] = Some(z);
379            }
380        }
381
382        let mut zwires = Vec::with_capacity(n);
383        for i in 0..n {
384            zwires.push(x[i][i + 1].take().unwrap());
385        }
386        Ok(Bundle::new(zwires))
387    }
388
389    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
390    /// it to work, there must be an extra modulus in the CRT that is not necessary to
391    /// represent the values. This ensures that if x < y, the most significant PMR digit
392    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
393    /// before using this gadget.
394    fn pmr_lt(
395        &mut self,
396        x: &CrtBundle<Self::Item>,
397        y: &CrtBundle<Self::Item>,
398    ) -> Result<Self::Item, Self::Error> {
399        let z = self.crt_sub(x, y)?;
400        let mut pmr = self.crt_to_pmr(&z)?;
401        let w = pmr.pop().unwrap();
402        let mut tab = vec![1; w.modulus() as usize];
403        tab[0] = 0;
404        self.proj(&w, 2, Some(tab))
405    }
406
407    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
408    /// it to work, there must be an extra modulus in the CRT that is not necessary to
409    /// represent the values. This ensures that if x < y, the most significant PMR digit
410    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
411    /// before using this gadget.
412    fn pmr_geq(
413        &mut self,
414        x: &CrtBundle<Self::Item>,
415        y: &CrtBundle<Self::Item>,
416    ) -> Result<Self::Item, Self::Error> {
417        let z = self.pmr_lt(x, y)?;
418        self.negate(&z)
419    }
420
421    /// Generic, and expensive, CRT-based addition for two ciphertexts. Uses PMR
422    /// comparison repeatedly. Requires an extra unused prime in both inputs.
423    fn crt_div(
424        &mut self,
425        x: &CrtBundle<Self::Item>,
426        y: &CrtBundle<Self::Item>,
427    ) -> Result<CrtBundle<Self::Item>, Self::Error> {
428        if x.moduli() != y.moduli() {
429            return Err(Self::Error::from(FancyError::UnequalModuli));
430        }
431
432        let q = x.composite_modulus();
433
434        // Compute l based on the assumption that the last prime is unused.
435        let nprimes = x.moduli().len();
436        let qs_ = &x.moduli()[..nprimes - 1];
437        let q_ = util::product(qs_);
438        let l = 128 - q_.leading_zeros();
439
440        let mut quotient = self.crt_constant_bundle(0, q)?;
441        let mut a = x.clone();
442
443        let one = self.crt_constant_bundle(1, q)?;
444        for i in 0..l {
445            let b = 2u128.pow(l - i - 1);
446            let mut pb = q_ / b;
447            if q_ % b == 0 {
448                pb -= 1;
449            }
450
451            let tmp = self.crt_cmul(y, b)?;
452            let c1 = self.pmr_geq(&a, &tmp)?;
453
454            let pb_crt = self.crt_constant_bundle(pb, q)?;
455            let c2 = self.pmr_geq(&pb_crt, y)?;
456
457            let c = self.and(&c1, &c2)?;
458
459            let c_ws = one
460                .iter()
461                .map(|w| self.mul(w, &c))
462                .collect::<Result<Vec<_>, _>>()?;
463            let c_crt = CrtBundle::new(c_ws);
464
465            let b_if = self.crt_cmul(&c_crt, b)?;
466            quotient = self.crt_add(&quotient, &b_if)?;
467
468            let tmp_if = self.crt_mul(&c_crt, &tmp)?;
469            a = self.crt_sub(&a, &tmp_if)?;
470        }
471
472        Ok(quotient)
473    }
474}
475
476/// Compute the `ms` needed for the number of CRT primes in `x`, with accuracy
477/// `accuracy`.
478///
479/// Supported accuracy: ["100%", "99.9%", "99%"]
480fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
481    match accuracy {
482        "100%" => match x.moduli().len() {
483            3 => vec![2; 5],
484            4 => vec![3, 26],
485            5 => vec![3, 4, 54],
486            6 => vec![5, 5, 5, 60],
487            7 => vec![5, 6, 6, 7, 86],
488            8 => vec![5, 7, 8, 8, 9, 98],
489            9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
490            10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
491            11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
492            n => panic!("unknown exact Ms for {} primes!", n),
493        },
494        "99.999%" => match x.moduli().len() {
495            8 => vec![5, 5, 6, 7, 102],
496            9 => vec![5, 5, 6, 7, 114],
497            10 => vec![5, 6, 6, 7, 102],
498            11 => vec![5, 5, 6, 7, 130],
499            n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
500        },
501        "99.99%" => match x.moduli().len() {
502            6 => vec![5, 5, 5, 42],
503            7 => vec![4, 5, 6, 88],
504            8 => vec![4, 5, 7, 78],
505            9 => vec![5, 5, 6, 84],
506            10 => vec![4, 5, 6, 112],
507            11 => vec![7, 11, 174],
508            n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
509        },
510        "99.9%" => match x.moduli().len() {
511            5 => vec![3, 5, 30],
512            6 => vec![4, 5, 48],
513            7 => vec![4, 5, 60],
514            8 => vec![3, 5, 78],
515            9 => vec![9, 140],
516            10 => vec![7, 190],
517            n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
518        },
519        "99%" => match x.moduli().len() {
520            4 => vec![3, 18],
521            5 => vec![3, 36],
522            6 => vec![3, 40],
523            7 => vec![3, 40],
524            8 => vec![126],
525            9 => vec![138],
526            10 => vec![140],
527            n => panic!("unknown 99% accurate Ms for {} primes!", n),
528        },
529        _ => panic!("get_ms: unsupported accuracy {}", accuracy),
530    }
531}