Skip to main content

fancy_garbling/fancy/
crt.rs

1//! Module containing `CrtGadgets`, which are the CRT-based gadgets for `Fancy`.
2
3use super::{HasModulus, bundle::ArithmeticBundleGadgets};
4use crate::{
5    FancyArithmetic, FancyBinary,
6    fancy::bundle::{ArithmeticProjBundleGadgets, Bundle, BundleGadgets},
7    util,
8};
9use itertools::Itertools;
10use std::ops::Deref;
11use swanky_channel::Channel;
12
13/// Bundle which is explicitly CRT-representation.
14#[derive(Clone)]
15pub struct CrtBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> CrtBundle<W> {
18    /// Create a new CRT bundle from a vector of wires.
19    pub fn new(ws: Vec<W>) -> CrtBundle<W> {
20        CrtBundle(Bundle::new(ws))
21    }
22
23    /// Extract the underlying bundle from this CRT bundle.
24    pub fn extract(self) -> Bundle<W> {
25        self.0
26    }
27
28    /// Return the product of all the wires' moduli.
29    pub fn composite_modulus(&self) -> u128 {
30        util::product(&self.iter().map(HasModulus::modulus).collect_vec())
31    }
32}
33
34impl<W: Clone + HasModulus> Deref for CrtBundle<W> {
35    type Target = Bundle<W>;
36
37    fn deref(&self) -> &Bundle<W> {
38        &self.0
39    }
40}
41
42impl<W: Clone + HasModulus> From<Bundle<W>> for CrtBundle<W> {
43    fn from(b: Bundle<W>) -> CrtBundle<W> {
44        CrtBundle(b)
45    }
46}
47
48impl<F: FancyArithmetic + FancyBinary> CrtGadgets for F {}
49
50/// Extension trait for `Fancy` providing advanced CRT gadgets based on bundles of wires.
51pub trait CrtGadgets:
52    FancyArithmetic + FancyBinary + ArithmeticBundleGadgets + BundleGadgets
53{
54    /// Encode a CRT input bundle.
55    fn crt_encode(
56        &mut self,
57        value: u128,
58        modulus: u128,
59        channel: &mut Channel,
60    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
61        let qs = util::factor(modulus);
62        let xs = util::crt(value, &qs);
63        self.encode_bundle(&xs, &qs, channel).map(CrtBundle::from)
64    }
65
66    /// Receive an CRT input bundle.
67    fn crt_receive(
68        &mut self,
69        modulus: u128,
70        channel: &mut Channel,
71    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
72        let qs = util::factor(modulus);
73        self.receive_bundle(&qs, channel).map(CrtBundle::from)
74    }
75
76    /// Encode many CRT input bundles.
77    fn crt_encode_many(
78        &mut self,
79        values: &[u128],
80        modulus: u128,
81        channel: &mut Channel,
82    ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
83        let mods = util::factor(modulus);
84        let nmods = mods.len();
85        let xs = values
86            .iter()
87            .flat_map(|x| util::crt(*x, &mods))
88            .collect_vec();
89        let qs = itertools::repeat_n(mods, values.len())
90            .flatten()
91            .collect_vec();
92        let mut wires = self.encode_many(&xs, &qs, channel)?;
93        let buns = (0..values.len())
94            .map(|_| {
95                let ws = wires.drain(0..nmods).collect_vec();
96                CrtBundle::new(ws)
97            })
98            .collect_vec();
99        Ok(buns)
100    }
101
102    /// Receive many CRT input bundles.
103    fn crt_receive_many(
104        &mut self,
105        n: usize,
106        modulus: u128,
107        channel: &mut Channel,
108    ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
109        let mods = util::factor(modulus);
110        let nmods = mods.len();
111        let qs = itertools::repeat_n(mods, n).flatten().collect_vec();
112        let mut wires = self.receive_many(&qs, channel)?;
113        let buns = (0..n)
114            .map(|_| {
115                let ws = wires.drain(0..nmods).collect_vec();
116                CrtBundle::new(ws)
117            })
118            .collect_vec();
119        Ok(buns)
120    }
121
122    /// Creates a bundle of constant wires for the CRT representation of `x` under
123    /// composite modulus `q`.
124    fn crt_constant_bundle(
125        &mut self,
126        x: u128,
127        q: u128,
128        channel: &mut Channel,
129    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
130        let ps = util::factor(q);
131        let xs = ps.iter().map(|&p| (x % p as u128) as u16).collect_vec();
132        self.constant_bundle(&xs, &ps, channel).map(CrtBundle)
133    }
134
135    /// Output a CRT bundle and interpret it mod Q.
136    fn crt_output(
137        &mut self,
138        x: &CrtBundle<Self::Item>,
139        channel: &mut Channel,
140    ) -> swanky_error::Result<Option<u128>> {
141        let q = x.composite_modulus();
142        Ok(self
143            .output_bundle(x, channel)?
144            .map(|xs| util::crt_inv_factor(&xs, q)))
145    }
146
147    /// Output a slice of CRT bundles and interpret the outputs mod Q.
148    fn crt_outputs(
149        &mut self,
150        xs: &[CrtBundle<Self::Item>],
151        channel: &mut Channel,
152    ) -> swanky_error::Result<Option<Vec<u128>>> {
153        let mut zs = Vec::with_capacity(xs.len());
154        for x in xs.iter() {
155            let z = self.crt_output(x, channel)?;
156            zs.push(z);
157        }
158        Ok(zs.into_iter().collect())
159    }
160
161    ////////////////////////////////////////////////////////////////////////////////
162    // High-level computations dealing with bundles.
163
164    /// Add two CRT bundles.
165    fn crt_add(
166        &mut self,
167        x: &CrtBundle<Self::Item>,
168        y: &CrtBundle<Self::Item>,
169    ) -> CrtBundle<Self::Item> {
170        CrtBundle(self.add_bundles(x, y))
171    }
172
173    /// Subtract two CRT bundles.
174    fn crt_sub(
175        &mut self,
176        x: &CrtBundle<Self::Item>,
177        y: &CrtBundle<Self::Item>,
178    ) -> CrtBundle<Self::Item> {
179        CrtBundle(self.sub_bundles(x, y))
180    }
181
182    /// Multiplies each wire in `x` by the corresponding residue of `c`.
183    fn crt_cmul(&mut self, x: &CrtBundle<Self::Item>, c: u128) -> CrtBundle<Self::Item> {
184        let cs = util::crt(c, &x.moduli());
185        CrtBundle::new(
186            x.wires()
187                .iter()
188                .zip(cs)
189                .map(|(x, c)| self.cmul(x, c))
190                .collect::<Vec<Self::Item>>(),
191        )
192    }
193
194    /// Multiply `x` with `y`.
195    fn crt_mul(
196        &mut self,
197        x: &CrtBundle<Self::Item>,
198        y: &CrtBundle<Self::Item>,
199        channel: &mut Channel,
200    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
201        self.mul_bundles(x, y, channel).map(CrtBundle)
202    }
203}
204
205impl<F: ArithmeticProjBundleGadgets + CrtGadgets> CrtProjGadgets for F {}
206
207/// CRT gadgets that use projection gates.
208pub trait CrtProjGadgets: ArithmeticProjBundleGadgets + CrtGadgets {
209    /// Exponentiate `x` by the constant `c`.
210    fn crt_cexp(
211        &mut self,
212        x: &CrtBundle<Self::Item>,
213        c: u16,
214        channel: &mut Channel,
215    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
216        x.wires()
217            .iter()
218            .map(|x| {
219                let p = x.modulus();
220                let tab = (0..p)
221                    .map(|x| ((x as u64).pow(c as u32) % p as u64) as u16)
222                    .collect_vec();
223                self.proj(x, p, Some(tab), channel)
224            })
225            .collect::<swanky_error::Result<_>>()
226            .map(CrtBundle::new)
227    }
228
229    /// Compute the remainder with respect to modulus `p`.
230    ///
231    /// # Panics
232    /// Panics if `p` is not a modulus contained in `x`.
233    fn crt_rem(
234        &mut self,
235        x: &CrtBundle<Self::Item>,
236        p: u16,
237        channel: &mut Channel,
238    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
239        let i = x.moduli().iter().position(|&q| p == q);
240        assert!(i.is_some(), "`p` is not a modulus in the `x` bundle");
241        let i = i.unwrap();
242        let w = &x.wires()[i];
243        x.moduli()
244            .iter()
245            .map(|&q| self.mod_change(w, q, channel))
246            .collect::<swanky_error::Result<_>>()
247            .map(CrtBundle::new)
248    }
249
250    ////////////////////////////////////////////////////////////////////////////////
251    // Fancy functions based on Mike's fractional mixed radix trick.
252
253    /// Helper function for advanced gadgets, returns the MSB of the fractional part of
254    /// `X/M` where `M=product(ms)`.
255    fn crt_fractional_mixed_radix(
256        &mut self,
257        bun: &CrtBundle<Self::Item>,
258        ms: &[u16],
259        channel: &mut Channel,
260    ) -> swanky_error::Result<Self::Item> {
261        let ndigits = ms.len();
262
263        let q = util::product(&bun.moduli());
264        let M = util::product(ms);
265
266        let mut ds = Vec::new();
267
268        for wire in bun.wires().iter() {
269            let p = wire.modulus();
270
271            let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
272
273            for x in 0..p {
274                let crt_coef = util::inv(((q / p as u128) % p as u128) as i128, p as i128);
275                let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
276                let digits = util::as_mixed_radix(y, ms);
277                for i in 0..ndigits {
278                    tabs[i].push(digits[i]);
279                }
280            }
281
282            let new_ds = tabs
283                .into_iter()
284                .enumerate()
285                .map(|(i, tt)| self.proj(wire, ms[i], Some(tt), channel))
286                .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
287
288            ds.push(Bundle::new(new_ds));
289        }
290
291        self.mixed_radix_addition_msb_only(&ds, channel)
292    }
293
294    /// Compute `max(x,0)`.
295    ///
296    /// Optional output moduli.
297    fn crt_relu(
298        &mut self,
299        x: &CrtBundle<Self::Item>,
300        accuracy: &str,
301        output_moduli: Option<&[u16]>,
302        channel: &mut Channel,
303    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
304        let factors_of_m = &get_ms(x, accuracy);
305        let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
306
307        // project the MSB to 0/1, whether or not it is less than p/2
308        let p = *factors_of_m.last().unwrap();
309        let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect_vec();
310        let mask = self.proj(&res, 2, Some(mask_tt), channel)?;
311
312        // use the mask to either output x or 0
313        output_moduli
314            .map(|ps| x.with_moduli(ps))
315            .as_ref()
316            .unwrap_or(x)
317            .wires()
318            .iter()
319            .map(|x| self.mul(x, &mask, channel))
320            .collect::<swanky_error::Result<_>>()
321            .map(CrtBundle::new)
322    }
323
324    /// Return 0 if `x` is positive and 1 if `x` is negative.
325    fn crt_sign(
326        &mut self,
327        x: &CrtBundle<Self::Item>,
328        accuracy: &str,
329        channel: &mut Channel,
330    ) -> swanky_error::Result<Self::Item> {
331        let factors_of_m = &get_ms(x, accuracy);
332        let res = self.crt_fractional_mixed_radix(x, factors_of_m, channel)?;
333        let p = *factors_of_m.last().unwrap();
334        let tt = (0..p).map(|x| (x >= p / 2) as u16).collect_vec();
335        self.proj(&res, 2, Some(tt), channel)
336    }
337
338    /// Return `if x >= 0 then 1 else -1`, where `-1` is interpreted as `Q-1`.
339    ///
340    /// If provided, will produce a bundle under `output_moduli` instead of `x.moduli()`
341    fn crt_sgn(
342        &mut self,
343        x: &CrtBundle<Self::Item>,
344        accuracy: &str,
345        output_moduli: Option<&[u16]>,
346        channel: &mut Channel,
347    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
348        let sign = self.crt_sign(x, accuracy, channel)?;
349        output_moduli
350            .unwrap_or(&x.moduli())
351            .iter()
352            .map(|&p| {
353                let tt = vec![1, p - 1];
354                self.proj(&sign, p, Some(tt), channel)
355            })
356            .collect::<swanky_error::Result<_>>()
357            .map(CrtBundle::new)
358    }
359
360    /// Returns 1 if `x < y`.
361    fn crt_lt(
362        &mut self,
363        x: &CrtBundle<Self::Item>,
364        y: &CrtBundle<Self::Item>,
365        accuracy: &str,
366        channel: &mut Channel,
367    ) -> swanky_error::Result<Self::Item> {
368        let z = self.crt_sub(x, y);
369        self.crt_sign(&z, accuracy, channel)
370    }
371
372    /// Returns 1 if `x >= y`.
373    fn crt_geq(
374        &mut self,
375        x: &CrtBundle<Self::Item>,
376        y: &CrtBundle<Self::Item>,
377        accuracy: &str,
378        channel: &mut Channel,
379    ) -> swanky_error::Result<Self::Item> {
380        let z = self.crt_lt(x, y, accuracy, channel)?;
381        Ok(self.negate(&z))
382    }
383
384    /// Compute the maximum bundle in `xs`.
385    ///
386    /// # Panics
387    /// Panics if `xs` is empty.
388    fn crt_max(
389        &mut self,
390        xs: &[CrtBundle<Self::Item>],
391        accuracy: &str,
392        channel: &mut Channel,
393    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
394        assert!(!xs.is_empty(), "`xs` cannot be empty");
395        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
396            let pos = self.crt_lt(&x, y, accuracy, channel)?;
397            let neg = self.negate(&pos);
398            Ok(CrtBundle::new(
399                x.wires()
400                    .iter()
401                    .zip(y.wires().iter())
402                    .map(|(x, y)| {
403                        let xp = self.mul(x, &neg, channel)?;
404                        let yp = self.mul(y, &pos, channel)?;
405                        Ok(self.add(&xp, &yp))
406                    })
407                    .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
408            ))
409        })
410    }
411
412    /// Convert the xs bundle to PMR representation. Useful for extracting out of CRT.
413    fn crt_to_pmr(
414        &mut self,
415        xs: &CrtBundle<Self::Item>,
416        channel: &mut Channel,
417    ) -> swanky_error::Result<Bundle<Self::Item>> {
418        let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
419            let pq = p as u32 + q as u32 - 1;
420            let mut tab = Vec::with_capacity(pq as usize);
421            for z in 0..pq {
422                let mut x = 0;
423                let mut y = 0;
424                'outer: for i in 0..p as u32 {
425                    for j in 0..q as u32 {
426                        if (i + pq - j) % pq == z {
427                            x = i;
428                            y = j;
429                            break 'outer;
430                        }
431                    }
432                }
433                debug_assert_eq!((x + pq - y) % pq, z);
434                tab.push(
435                    (((x * q as u32 * util::inv(q as i128, p as i128) as u32
436                        + y * p as u32 * util::inv(p as i128, q as i128) as u32)
437                        / p as u32)
438                        % q as u32) as u16,
439                );
440            }
441            tab
442        };
443
444        let mut gadget = |x: &Self::Item, y: &Self::Item| -> swanky_error::Result<Self::Item> {
445            let p = x.modulus();
446            let q = y.modulus();
447            let x_ = self.mod_change(x, p + q - 1, channel)?;
448            let y_ = self.mod_change(y, p + q - 1, channel)?;
449            let z = self.sub(&x_, &y_);
450            self.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
451        };
452
453        let n = xs.size();
454        let mut x = vec![vec![None; n + 1]; n + 1];
455
456        for j in 0..n {
457            x[0][j + 1] = Some(xs.wires()[j].clone());
458        }
459
460        for i in 1..=n {
461            for j in i + 1..=n {
462                let z = gadget(x[i - 1][i].as_ref().unwrap(), x[i - 1][j].as_ref().unwrap())?;
463                x[i][j] = Some(z);
464            }
465        }
466
467        let mut zwires = Vec::with_capacity(n);
468        for i in 0..n {
469            zwires.push(x[i][i + 1].take().unwrap());
470        }
471        Ok(Bundle::new(zwires))
472    }
473
474    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
475    /// it to work, there must be an extra modulus in the CRT that is not necessary to
476    /// represent the values. This ensures that if x < y, the most significant PMR digit
477    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
478    /// before using this gadget.
479    fn pmr_lt(
480        &mut self,
481        x: &CrtBundle<Self::Item>,
482        y: &CrtBundle<Self::Item>,
483        channel: &mut Channel,
484    ) -> swanky_error::Result<Self::Item> {
485        let z = self.crt_sub(x, y);
486        let mut pmr = self.crt_to_pmr(&z, channel)?;
487        let w = pmr.pop().unwrap();
488        let mut tab = vec![1; w.modulus() as usize];
489        tab[0] = 0;
490        self.proj(&w, 2, Some(tab), channel)
491    }
492
493    /// Comparison based on PMR, more expensive than crt_lt but works on more things. For
494    /// it to work, there must be an extra modulus in the CRT that is not necessary to
495    /// represent the values. This ensures that if x < y, the most significant PMR digit
496    /// is nonzero after subtracting them. You could add a prime to your CrtBundles right
497    /// before using this gadget.
498    fn pmr_geq(
499        &mut self,
500        x: &CrtBundle<Self::Item>,
501        y: &CrtBundle<Self::Item>,
502        channel: &mut Channel,
503    ) -> swanky_error::Result<Self::Item> {
504        let z = self.pmr_lt(x, y, channel)?;
505        Ok(self.negate(&z))
506    }
507
508    /// Generic, and expensive, CRT-based addition for two ciphertexts. Uses PMR
509    /// comparison repeatedly. Requires an extra unused prime in both inputs.
510    ///
511    /// # Panics
512    /// Panics if `x` and `y` do not have equal moduli.
513    fn crt_div(
514        &mut self,
515        x: &CrtBundle<Self::Item>,
516        y: &CrtBundle<Self::Item>,
517        channel: &mut Channel,
518    ) -> swanky_error::Result<CrtBundle<Self::Item>> {
519        assert_eq!(x.moduli(), y.moduli());
520
521        let q = x.composite_modulus();
522
523        // Compute l based on the assumption that the last prime is unused.
524        let nprimes = x.moduli().len();
525        let qs_ = &x.moduli()[..nprimes - 1];
526        let q_ = util::product(qs_);
527        let l = 128 - q_.leading_zeros();
528
529        let mut quotient = self.crt_constant_bundle(0, q, channel)?;
530        let mut a = x.clone();
531
532        let one = self.crt_constant_bundle(1, q, channel)?;
533        for i in 0..l {
534            let b = 2u128.pow(l - i - 1);
535            let mut pb = q_ / b;
536            if q_.is_multiple_of(b) {
537                pb -= 1;
538            }
539
540            let tmp = self.crt_cmul(y, b);
541            let c1 = self.pmr_geq(&a, &tmp, channel)?;
542
543            let pb_crt = self.crt_constant_bundle(pb, q, channel)?;
544            let c2 = self.pmr_geq(&pb_crt, y, channel)?;
545
546            let c = self.and(&c1, &c2, channel)?;
547
548            let c_ws = one
549                .iter()
550                .map(|w| self.mul(w, &c, channel))
551                .collect::<Result<Vec<_>, _>>()?;
552            let c_crt = CrtBundle::new(c_ws);
553
554            let b_if = self.crt_cmul(&c_crt, b);
555            quotient = self.crt_add(&quotient, &b_if);
556
557            let tmp_if = self.crt_mul(&c_crt, &tmp, channel)?;
558            a = self.crt_sub(&a, &tmp_if);
559        }
560
561        Ok(quotient)
562    }
563}
564
565/// Compute the `ms` needed for the number of CRT primes in `x`, with accuracy
566/// `accuracy`.
567///
568/// Supported accuracy: ["100%", "99.9%", "99%"]
569fn get_ms<W: Clone + HasModulus>(x: &Bundle<W>, accuracy: &str) -> Vec<u16> {
570    match accuracy {
571        "100%" => match x.moduli().len() {
572            3 => vec![2; 5],
573            4 => vec![3, 26],
574            5 => vec![3, 4, 54],
575            6 => vec![5, 5, 5, 60],
576            7 => vec![5, 6, 6, 7, 86],
577            8 => vec![5, 7, 8, 8, 9, 98],
578            9 => vec![5, 5, 7, 7, 7, 7, 7, 76],
579            10 => vec![5, 5, 6, 6, 6, 6, 11, 11, 202],
580            11 => vec![5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 150],
581            n => panic!("unknown exact Ms for {} primes!", n),
582        },
583        "99.999%" => match x.moduli().len() {
584            8 => vec![5, 5, 6, 7, 102],
585            9 => vec![5, 5, 6, 7, 114],
586            10 => vec![5, 6, 6, 7, 102],
587            11 => vec![5, 5, 6, 7, 130],
588            n => panic!("unknown 99.999% accurate Ms for {} primes!", n),
589        },
590        "99.99%" => match x.moduli().len() {
591            6 => vec![5, 5, 5, 42],
592            7 => vec![4, 5, 6, 88],
593            8 => vec![4, 5, 7, 78],
594            9 => vec![5, 5, 6, 84],
595            10 => vec![4, 5, 6, 112],
596            11 => vec![7, 11, 174],
597            n => panic!("unknown 99.99% accurate Ms for {} primes!", n),
598        },
599        "99.9%" => match x.moduli().len() {
600            5 => vec![3, 5, 30],
601            6 => vec![4, 5, 48],
602            7 => vec![4, 5, 60],
603            8 => vec![3, 5, 78],
604            9 => vec![9, 140],
605            10 => vec![7, 190],
606            n => panic!("unknown 99.9% accurate Ms for {} primes!", n),
607        },
608        "99%" => match x.moduli().len() {
609            4 => vec![3, 18],
610            5 => vec![3, 36],
611            6 => vec![3, 40],
612            7 => vec![3, 40],
613            8 => vec![126],
614            9 => vec![138],
615            10 => vec![140],
616            n => panic!("unknown 99% accurate Ms for {} primes!", n),
617        },
618        _ => panic!("get_ms: unsupported accuracy {}", accuracy),
619    }
620}