Skip to main content

fancy_garbling/circuits/arithmetic/
mask.rs

1use crate::{CrtBundle, FancyArithmetic, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Given a wire `b` and a [`CrtBundle`] `x`, output `0` if `b == 0`, otherwise
/// output `x`.
///
/// This is equivalent to computing `b * x` for each wire in the bundle. The
/// "otherwise output `x`" behavior therefore assumes `b` carries the value `0` or
/// `1` (e.g. a mod-2 wire). If `b` holds some other value `k`, each wire of `x` is
/// multiplied by `k` instead, under the backend's `mul` semantics.
///
/// The struct carries no data. The lifetime `'a` only ties the circuit to the
/// borrowed `(b, x)` inputs it is executed on.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mask<'a>(PhantomData<&'a ()>);
12
13impl<'a> Mask<'a> {
14    /// Create a new [`Mask`] circuit.
15    pub fn new() -> Self {
16        Default::default()
17    }
18}
19
20impl<'a, F: FancyArithmetic> Circuit<F> for Mask<'a>
21where
22    F::Item: 'a,
23{
24    type Input = (&'a F::Item, &'a CrtBundle<F::Item>);
25    type Output = CrtBundle<F::Item>;
26
27    fn execute(
28        &self,
29        backend: &mut F,
30        inputs: Self::Input,
31        channel: &mut Channel,
32    ) -> Result<Self::Output> {
33        let (b, x) = inputs;
34        Ok(CrtBundle::new(
35            x.wires()
36                .iter()
37                .map(|xwire| backend.mul(xwire, b, channel))
38                .collect::<Result<_>>()?,
39        ))
40    }
41}
42
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::Mask,
        dummy::{Dummy, DummyVal},
        util::RngExt,
    };
    use rand::{Rng, thread_rng};

    #[test]
    fn mask() {
        let mut rng = thread_rng();
        let modulus = rng.gen_usable_composite_modulus();
        // `Mask` is stateless, so a single instance serves every round.
        let circuit = Mask::new();

        for _round in 0..16 {
            let bit = rng.r#gen::<bool>();
            let value = rng.r#gen::<u128>() % modulus;

            // The mask is a mod-2 wire; the value is CRT-encoded under `modulus`.
            let bit_wire = DummyVal::new(u16::from(bit), 2);
            let value_bundle = DummyVal::to_crt(value, modulus);

            let masked = Dummy::eval(&circuit, (&bit_wire, &value_bundle)).unwrap();
            let decoded = DummyVal::from_crt(&masked, modulus);

            let expected = if bit { value } else { 0 };
            assert_eq!(decoded, expected);
        }
    }
}