Skip to main content

fancy_garbling/circuits/arithmetic/comparison.rs

1use crate::{
2    CrtBundle, FancyArithmetic, FancyBinary, FancyProj,
3    circuit::Circuit,
4    circuits::arithmetic::{FractionalMixedRadix, Subtraction},
5    util::get_ms,
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
/// For [`CrtBundle`] `x`, return 0 if `x >= 0`, 1 otherwise.
///
/// `x` is read as a signed value modulo `Q` (the modulus of `x`): values in
/// `[0, Q/2)` count as non-negative and values in `[Q/2, Q)` count as negative.
/// The circuit takes `(x, accuracy)`, where `accuracy` (e.g. `"100%"`) selects the
/// mixed-radix decomposition used internally, and outputs a single mod-2 wire.
/// NOTE(review): accuracies below `"100%"` presumably make the result approximate
/// near the boundaries — confirm against `get_ms`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Sign<'a>(PhantomData<&'a ()>);

impl<'a> Sign<'a> {
    /// Create a new [`Sign`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
21
22impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for Sign<'a>
23where
24    F::Item: 'a,
25{
26    type Input = (&'a CrtBundle<F::Item>, &'a str);
27    type Output = F::Item;
28
29    fn execute(
30        &self,
31        backend: &mut F,
32        inputs: Self::Input,
33        channel: &mut Channel,
34    ) -> Result<Self::Output> {
35        let (x, accuracy) = inputs;
36        let factors_of_m = get_ms(x, accuracy);
37        let res = FractionalMixedRadix::new().execute(backend, (x, &factors_of_m), channel)?;
38        let p = factors_of_m.last().unwrap();
39        let tt = (0..*p).map(|x| (x >= p / 2) as u16).collect::<Vec<_>>();
40        backend.proj(&res, 2, Some(tt), channel)
41    }
42}
43
/// For [`CrtBundle`]s `x` and `y`, return `x < y`.
///
/// This is computed as the sign of `x - y`. The result is therefore only meaningful
/// when `x - y` does not wrap around the modulus, for example when both inputs lie
/// in `[0, Q/2)`. The circuit takes `(x, y, accuracy)` and outputs a single mod-2
/// wire.
#[derive(Clone, Copy, Debug, Default)]
pub struct LessThan<'a>(PhantomData<&'a ()>);

impl<'a> LessThan<'a> {
    /// Create a new [`LessThan`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
54
55impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for LessThan<'a>
56where
57    F::Item: 'a,
58{
59    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>, &'a str);
60    type Output = F::Item;
61
62    fn execute(
63        &self,
64        backend: &mut F,
65        inputs: Self::Input,
66        channel: &mut Channel,
67    ) -> Result<Self::Output> {
68        let (x, y, accuracy) = inputs;
69        let z = Subtraction::new().execute(backend, (x, y), channel)?;
70        Sign::new().execute(backend, (&z, accuracy), channel)
71    }
72}
73
/// For [`CrtBundle`]s `x` and `y`, return `x >= y`.
///
/// This is the negation of [`LessThan`], so it has the same requirement that
/// `x - y` must not wrap around the modulus. The circuit takes `(x, y, accuracy)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct GreaterThanOrEqual<'a>(PhantomData<&'a ()>);

impl<'a> GreaterThanOrEqual<'a> {
    /// Create a new [`GreaterThanOrEqual`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
84
85impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for GreaterThanOrEqual<'a>
86where
87    F::Item: 'a,
88{
89    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>, &'a str);
90    type Output = F::Item;
91
92    fn execute(
93        &self,
94        backend: &mut F,
95        inputs: Self::Input,
96        channel: &mut Channel,
97    ) -> Result<Self::Output> {
98        let z = LessThan::new().execute(backend, inputs, channel)?;
99        Ok(backend.negate(&z))
100    }
101}
102
/// For a slice of [`CrtBundle`]s `xs`, return `max(xs)`.
///
/// The inputs are compared pairwise with [`LessThan`], so every pair must satisfy
/// its no-wraparound requirement. The circuit takes `(xs, accuracy)`.
///
/// # Panics
///
/// Panics if `xs` is empty.
#[derive(Clone, Copy, Debug, Default)]
pub struct Max<'a>(PhantomData<&'a ()>);

impl<'a> Max<'a> {
    /// Create a new [`Max`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
113
114impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for Max<'a>
115where
116    F::Item: 'a,
117{
118    type Input = (&'a [CrtBundle<F::Item>], &'a str);
119    type Output = CrtBundle<F::Item>;
120
121    fn execute(
122        &self,
123        backend: &mut F,
124        inputs: Self::Input,
125        channel: &mut Channel,
126    ) -> Result<Self::Output> {
127        let (xs, accuracy) = inputs;
128        assert!(!xs.is_empty(), "`xs` cannot be empty");
129
130        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
131            let pos = LessThan::new().execute(backend, (&x, y, accuracy), channel)?;
132            let neg = backend.negate(&pos);
133            Ok(CrtBundle::new(
134                x.wires()
135                    .iter()
136                    .zip(y.wires().iter())
137                    .map(|(x, y)| {
138                        let xp = backend.mul(x, &neg, channel)?;
139                        let yp = backend.mul(y, &pos, channel)?;
140                        Ok(backend.add(&xp, &yp))
141                    })
142                    .collect::<Result<Vec<_>>>()?,
143            ))
144        })
145    }
146}
147
/// For [`CrtBundle`] `x`, if `x >= 0` return `1`, otherwise return `-1`, where
/// `-1` is interpreted as `Q - 1` and `Q` is the modulus of `x`.
///
/// If `output_moduli` is provided, output the result using the provided moduli
/// instead of the moduli of `x`. In that case `-1` is encoded as `p - 1` on each
/// output wire of modulus `p`. The circuit takes `(x, accuracy, output_moduli)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Sgn<'a>(PhantomData<&'a ()>);

impl<'a> Sgn<'a> {
    /// Create a new [`Sgn`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
161
162impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for Sgn<'a>
163where
164    F::Item: 'a,
165{
166    type Input = (&'a CrtBundle<F::Item>, &'a str, Option<&'a [u16]>);
167    type Output = CrtBundle<F::Item>;
168
169    fn execute(
170        &self,
171        backend: &mut F,
172        inputs: Self::Input,
173        channel: &mut Channel,
174    ) -> Result<Self::Output> {
175        let (x, accuracy, output_moduli) = inputs;
176        let sign = Sign::new().execute(backend, (x, accuracy), channel)?;
177        output_moduli
178            .map(|m| m.to_vec())
179            .unwrap_or_else(|| x.moduli())
180            .iter()
181            .map(|&p| {
182                let tt = vec![1, p - 1];
183                backend.proj(&sign, p, Some(tt), channel)
184            })
185            .collect::<Result<_>>()
186            .map(CrtBundle::new)
187    }
188}
189
/// For [`CrtBundle`] `x`, output `max(0, x)`.
///
/// `x` counts as negative when it lies in the upper half of its range, with the same
/// convention as [`Sign`]. The circuit takes `(x, accuracy, output_moduli)`. If
/// `output_moduli` is provided, the output is built from `x.with_moduli(output_moduli)`
/// (presumably the wires of `x` for those moduli — NOTE(review): confirm); otherwise
/// all wires of `x` are used.
#[derive(Clone, Copy, Debug, Default)]
pub struct ReLU<'a>(PhantomData<&'a ()>);

impl<'a> ReLU<'a> {
    /// Create a new [`ReLU`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
200
201impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for ReLU<'a>
202where
203    F::Item: 'a,
204{
205    type Input = (&'a CrtBundle<F::Item>, &'a str, Option<&'a [u16]>);
206    type Output = CrtBundle<F::Item>;
207
208    fn execute(
209        &self,
210        backend: &mut F,
211        inputs: Self::Input,
212        channel: &mut Channel,
213    ) -> Result<Self::Output> {
214        let (x, accuracy, output_moduli) = inputs;
215        let factors_of_m = get_ms(x, accuracy);
216        let res = FractionalMixedRadix::new().execute(backend, (x, &factors_of_m), channel)?;
217
218        // project the MSB to 0/1, whether or not it is less than p/2
219        let p = *factors_of_m.last().unwrap();
220        let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect::<Vec<_>>();
221        let mask = backend.proj(&res, 2, Some(mask_tt), channel)?;
222
223        // use the mask to either output x or 0
224        let output_bundle = match output_moduli {
225            Some(ps) => x.with_moduli(ps),
226            None => (*x).clone().extract(),
227        };
228
229        output_bundle
230            .wires()
231            .iter()
232            .map(|x| backend.mul(x, &mask, channel))
233            .collect::<Result<_>>()
234            .map(CrtBundle::new)
235    }
236}
237
#[cfg(test)]
mod test {
    use rand::{Rng, thread_rng};

    use crate::{
        circuits::arithmetic::comparison::{GreaterThanOrEqual, LessThan, Max, ReLU, Sgn, Sign},
        dummy::{Dummy, DummyVal},
        util::modulus_with_width,
    };

    #[test]
    fn sign() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        // `Sign(0) == 0`: zero is non-negative.
        let zero = DummyVal::to_crt(0, q);
        let output = Dummy::eval(&Sign::new(), (&zero, accuracy)).unwrap();
        assert_eq!(output.val(), 0);

        // Values in `[0, q/2)` are non-negative; values in `[q/2, q)` are negative.
        for _ in 0..64 {
            let x = rng.r#gen::<u128>() % q;
            let x_input = DummyVal::to_crt(x, q);
            let output = Dummy::eval(&Sign::new(), (&x_input, accuracy)).unwrap();
            assert_eq!(output.val(), if x < q / 2 { 0 } else { 1 });
        }
    }

    #[test]
    fn less_than() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        // `x < x` is always false.
        let x = (rng.r#gen::<u128>() % q) / 2;
        let x_input = DummyVal::to_crt(x, q);
        let output = Dummy::eval(&LessThan::new(), (&x_input, &x_input, accuracy)).unwrap();
        assert_eq!(output.val(), 0);

        // Inputs stay below `q / 2` so that `x - y` never wraps around the modulus.
        for _ in 0..64 {
            let x = (rng.r#gen::<u128>() % q) / 2;
            let y = (rng.r#gen::<u128>() % q) / 2;
            let x_input = DummyVal::to_crt(x, q);
            let y_input = DummyVal::to_crt(y, q);
            let output = Dummy::eval(&LessThan::new(), (&x_input, &y_input, accuracy)).unwrap();
            assert_eq!(output.val(), (x < y) as u16);
        }
    }

    #[test]
    fn greater_than_or_equal() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        // `x >= x` is always true.
        let x = (rng.r#gen::<u128>() % q) / 2;
        let x_input = DummyVal::to_crt(x, q);
        let output =
            Dummy::eval(&GreaterThanOrEqual::new(), (&x_input, &x_input, accuracy)).unwrap();
        assert_eq!(output.val(), 1);

        // Inputs stay below `q / 2` so that `x - y` never wraps around the modulus.
        for _ in 0..64 {
            let x = (rng.r#gen::<u128>() % q) / 2;
            let y = (rng.r#gen::<u128>() % q) / 2;
            let x_input = DummyVal::to_crt(x, q);
            let y_input = DummyVal::to_crt(y, q);
            let output =
                Dummy::eval(&GreaterThanOrEqual::new(), (&x_input, &y_input, accuracy)).unwrap();
            assert_eq!(output.val(), (x >= y) as u16);
        }
    }

    #[test]
    fn max() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        for _ in 0..16 {
            // Inputs stay below `q / 2` so that every pairwise difference is wrap-free.
            let inputs = (0..100)
                .map(|_| rng.r#gen::<u128>() % (q / 2))
                .collect::<Vec<_>>();
            let expected = *inputs.iter().max().unwrap();

            let inputs = inputs
                .into_iter()
                .map(|x| DummyVal::to_crt(x, q))
                .collect::<Vec<_>>();
            let z = Dummy::eval(&Max::new(), (&inputs[..], accuracy)).unwrap();
            let output = DummyVal::from_crt(&z, q);
            assert_eq!(output, expected);
        }
    }

    #[test]
    fn sgn() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        // `Sgn(0) == 1`: zero is non-negative.
        let zero = DummyVal::to_crt(0, q);
        let z = Dummy::eval(&Sgn::new(), (&zero, accuracy, None)).unwrap();
        let output = DummyVal::from_crt(&z, q);
        assert_eq!(output, 1);

        // Negative values map to `-1`, i.e. `q - 1`.
        for _ in 0..64 {
            let x = rng.r#gen::<u128>() % q;
            let x_input = DummyVal::to_crt(x, q);
            let z = Dummy::eval(&Sgn::new(), (&x_input, accuracy, None)).unwrap();
            let output = DummyVal::from_crt(&z, q);
            assert_eq!(output, if x < q / 2 { 1 } else { q - 1 });
        }
    }

    #[test]
    fn relu() {
        let mut rng = thread_rng();
        let accuracy = "100%";
        let q = modulus_with_width(10);

        // `ReLU(0) == 0`.
        let zero = DummyVal::to_crt(0, q);
        let z = Dummy::eval(&ReLU::new(), (&zero, accuracy, None)).unwrap();
        let output = DummyVal::from_crt(&z, q);
        assert_eq!(output, 0);

        // Non-negative values pass through unchanged; negative values become zero.
        for _ in 0..64 {
            let x = rng.r#gen::<u128>() % q;
            let x_input = DummyVal::to_crt(x, q);
            let z = Dummy::eval(&ReLU::new(), (&x_input, accuracy, None)).unwrap();
            let output = DummyVal::from_crt(&z, q);
            assert_eq!(output, if x < q / 2 { x } else { 0 });
        }
    }
}