Skip to main content

fancy_garbling/circuits/arithmetic/
pmr.rs

1use crate::{
2    CrtBundle, FancyArithmetic, FancyBinary, FancyProj, HasModulus,
3    circuit::Circuit,
4    circuits::arithmetic::{ModChange, Subtraction},
5    util::inv,
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
/// Circuit that rewrites a [`CrtBundle`] into PMR (mixed-radix) form.
///
/// For CRT moduli `p_0, …, p_{n-1}` the result holds digits `d_i`
/// (digit `i` on a wire of modulus `p_i`) such that
/// `x = d_0 + d_1·p_0 + d_2·p_0·p_1 + …`.
#[derive(Default)]
pub struct ToPmr<'a>(PhantomData<&'a ()>);

impl ToPmr<'_> {
    /// Build a [`ToPmr`] circuit; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
21
22impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for ToPmr<'a>
23where
24    F::Item: 'a,
25{
26    type Input = &'a CrtBundle<F::Item>;
27    type Output = CrtBundle<F::Item>;
28
29    fn execute(
30        &self,
31        backend: &mut F,
32        inputs: Self::Input,
33        channel: &mut Channel,
34    ) -> Result<Self::Output> {
35        let xs = inputs;
36        let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
37            let pq = p as u32 + q as u32 - 1;
38            let mut tab = Vec::with_capacity(pq as usize);
39            for z in 0..pq {
40                let mut x = 0;
41                let mut y = 0;
42                'outer: for i in 0..p as u32 {
43                    for j in 0..q as u32 {
44                        if (i + pq - j) % pq == z {
45                            x = i;
46                            y = j;
47                            break 'outer;
48                        }
49                    }
50                }
51                debug_assert_eq!((x + pq - y) % pq, z);
52                tab.push(
53                    (((x * q as u32 * inv(q as i128, p as i128) as u32
54                        + y * p as u32 * inv(p as i128, q as i128) as u32)
55                        / p as u32)
56                        % q as u32) as u16,
57                );
58            }
59            tab
60        };
61
62        let mut gadget = |x: F::Item, y: F::Item| -> Result<F::Item> {
63            let p = x.modulus();
64            let q = y.modulus();
65            let x_ = ModChange.execute(backend, (x, p + q - 1), channel)?;
66            let y_ = ModChange.execute(backend, (y, p + q - 1), channel)?;
67            let z = backend.sub(&x_, &y_);
68            backend.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
69        };
70
71        let n = xs.size();
72        let mut x = vec![vec![None; n + 1]; n + 1];
73
74        for j in 0..n {
75            x[0][j + 1] = Some(xs.wires()[j].clone());
76        }
77
78        for i in 1..=n {
79            for j in i + 1..=n {
80                let z = gadget(x[i - 1][i].clone().unwrap(), x[i - 1][j].clone().unwrap())?;
81                x[i][j] = Some(z);
82            }
83        }
84
85        let mut zwires = Vec::with_capacity(n);
86        for i in 0..n {
87            zwires.push(x[i][i + 1].take().unwrap());
88        }
89        Ok(CrtBundle::new(zwires))
90    }
91}
92
/// Circuit computing `x < y` for [`CrtBundle`]s `x` and `y` via PMR digits.
///
/// The CRT must contain one spare modulus beyond what the values need:
/// both operands have to be smaller than the product of every modulus except
/// the last. Under that condition the most significant PMR digit of `x - y`
/// is nonzero exactly when `x < y`. If your bundles lack such a modulus, add
/// an extra prime to them right before using this gadget.
#[derive(Default)]
pub struct PmrLessThan<'a>(PhantomData<&'a ()>);

impl PmrLessThan<'_> {
    /// Build a [`PmrLessThan`] circuit; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
108
109impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for PmrLessThan<'a>
110where
111    F::Item: 'a,
112{
113    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
114    type Output = F::Item;
115
116    fn execute(
117        &self,
118        backend: &mut F,
119        inputs: Self::Input,
120        channel: &mut Channel,
121    ) -> Result<Self::Output> {
122        let z = Subtraction::new().execute(backend, inputs, channel)?;
123        let mut pmr = ToPmr::new().execute(backend, &z, channel)?;
124        let w = pmr.pop().unwrap();
125        let mut tab = vec![1; w.modulus() as usize];
126        tab[0] = 0;
127        backend.proj(&w, 2, Some(tab), channel)
128    }
129}
130
/// For [`CrtBundle`]s `x` and `y`, output `x >= y` using PMR representation.
///
/// For this to work, there must be an extra modulus in the CRT that is not
/// necessary to represent the values. Both `x` and `y` must be smaller than
/// the product of every modulus except the last. This ensures that the most
/// significant PMR digit of `x - y` is zero exactly when `x >= y`, so the
/// result is the negation of [`PmrLessThan`]. You could add a prime to your
/// [`CrtBundle`]s right before using this gadget.
#[derive(Default)]
pub struct PmrGreaterThanOrEqual<'a>(PhantomData<&'a ()>);

impl<'a> PmrGreaterThanOrEqual<'a> {
    /// Create a new [`PmrGreaterThanOrEqual`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
146
147impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for PmrGreaterThanOrEqual<'a>
148where
149    F::Item: 'a,
150{
151    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
152    type Output = F::Item;
153
154    fn execute(
155        &self,
156        backend: &mut F,
157        inputs: Self::Input,
158        channel: &mut Channel,
159    ) -> Result<Self::Output> {
160        let z = PmrLessThan::new().execute(backend, inputs, channel)?;
161        Ok(backend.negate(&z))
162    }
163}
164
#[cfg(test)]
mod test {
    use rand::Rng;

    use crate::{
        circuits::arithmetic::{
            ToPmr,
            pmr::{PmrGreaterThanOrEqual, PmrLessThan},
        },
        dummy::{Dummy, DummyVal},
        util::{RngExt, product},
    };

    /// Plaintext reference: the mixed-radix digits of `x` with radices `ps`,
    /// least significant first.
    fn to_pmr_pt(x: u128, ps: &[u16]) -> Vec<u16> {
        let mut ds = vec![0; ps.len()];
        let mut q = 1;
        for i in 0..ps.len() {
            let p = ps[i] as u128;
            ds[i] = ((x / q) % p) as u16;
            q *= p;
        }
        ds
    }

    #[test]
    fn to_pmr() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let ps = rng.gen_usable_factors();
            let q = product(&ps);

            // Cover both ends of the representable range as well as a random value.
            for x in [0, q - 1, rng.r#gen::<u128>() % q] {
                let x_input = DummyVal::to_crt(x, q);
                let z = Dummy::eval(&ToPmr::new(), &x_input).unwrap();
                let output = z.wires().iter().map(|w| w.val()).collect::<Vec<_>>();
                assert_eq!(output, to_pmr_pt(x, &ps));
            }
        }
    }

    #[test]
    fn pmr_less_than() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = product(&qs);
            // Operands must fit in every modulus but the last one.
            let q_ = product(&qs[..n - 1]);
            let x = rng.r#gen::<u128>() % q_;
            let y = rng.r#gen::<u128>() % q_;

            // `(x, x)` pins the equality boundary, which random draws almost never hit.
            for (a, b) in [(x, y), (y, x), (x, x)] {
                let a_input = DummyVal::to_crt(a, q);
                let b_input = DummyVal::to_crt(b, q);
                let output = Dummy::eval(&PmrLessThan::new(), (&a_input, &b_input)).unwrap();
                assert_eq!(output.val(), u16::from(a < b));
            }
        }
    }

    #[test]
    fn pmr_greater_than_or_equal() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = product(&qs);
            // Operands must fit in every modulus but the last one.
            let q_ = product(&qs[..n - 1]);
            let x = rng.r#gen::<u128>() % q_;
            let y = rng.r#gen::<u128>() % q_;

            // `(x, x)` pins the equality boundary, which random draws almost never hit.
            for (a, b) in [(x, y), (y, x), (x, x)] {
                let a_input = DummyVal::to_crt(a, q);
                let b_input = DummyVal::to_crt(b, q);
                let output =
                    Dummy::eval(&PmrGreaterThanOrEqual::new(), (&a_input, &b_input)).unwrap();
                assert_eq!(output.val(), u16::from(a >= b));
            }
        }
    }
}