Skip to main content

fancy_garbling/circuits/arithmetic/
equality.rs

1use crate::{
2    CrtBundle, FancyArithmetic, FancyProj, HasModulus, circuit::Circuit,
3    circuits::arithmetic::addition::AddMany,
4};
5use core::marker::PhantomData;
6use swanky_channel::Channel;
7use swanky_error::Result;
8
/// Circuit that compares two [`CrtBundle`]s `x` and `y`, producing `x == y`.
///
/// The type holds no data; the lifetime `'a` only ties the circuit to the
/// borrowed bundles it receives as input.
#[derive(Default)]
pub struct Equality<'a>(PhantomData<&'a ()>);

impl<'a> Equality<'a> {
    /// Construct an [`Equality`] circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
19
20impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for Equality<'a>
21where
22    F::Item: 'a,
23{
24    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
25    type Output = F::Item;
26
27    fn execute(
28        &self,
29        backend: &mut F,
30        inputs: Self::Input,
31        channel: &mut Channel,
32    ) -> Result<Self::Output> {
33        let (x, y) = inputs;
34        assert_eq!(x.moduli(), y.moduli());
35
36        let wlen = x.wires().len() as u16;
37        let zs = x
38            .wires()
39            .iter()
40            .zip(y.wires().iter())
41            .map(|(x, y)| {
42                // compute (x-y == 0) for each residue
43                let z = backend.sub(x, y);
44                let mut eq_zero_tab = vec![0; x.modulus() as usize];
45                eq_zero_tab[0] = 1;
46                backend.proj(&z, wlen + 1, Some(eq_zero_tab), channel)
47            })
48            .collect::<Result<Vec<_>>>()?;
49        // add up the results, and output whether they equal zero or not, mod 2
50        let z = AddMany::new().execute(backend, zs.as_slice(), channel)?;
51        let b = zs.len();
52        let mut tab = vec![0; b + 1];
53        tab[b] = 1;
54        backend.proj(&z, 2, Some(tab), channel)
55    }
56}
57
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::Equality,
        dummy::{Dummy, DummyVal},
        util::RngExt,
    };
    use rand::{Rng, thread_rng};

    /// Runs [`Equality`] on the dummy backend: once with a bundle compared
    /// against itself, then on 64 random pairs reduced modulo a random
    /// composite modulus, checking the output against native `==`.
    #[test]
    fn equality() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        // A bundle compared with itself must always yield 1. The expected
        // value is written as a literal rather than `(x == x) as u16`, which
        // compares an operand with itself (Clippy `eq_op`).
        let x = rng.r#gen::<u128>() % q;
        let x_input = DummyVal::to_crt(x, q);
        let output = Dummy::eval(&Equality::new(), (&x_input, &x_input)).unwrap();
        assert_eq!(output.val(), 1_u16);

        // Random pairs: the circuit output must agree with native equality.
        for _ in 0..64 {
            let x = rng.r#gen::<u128>() % q;
            let y = rng.r#gen::<u128>() % q;
            let x_input = DummyVal::to_crt(x, q);
            let y_input = DummyVal::to_crt(y, q);
            let output = Dummy::eval(&Equality::new(), (&x_input, &y_input)).unwrap();
            assert_eq!(output.val(), (x == y) as u16);
        }
    }
}