Skip to main content

fancy_garbling/circuits/arithmetic/
remainder.rs

1use crate::{CrtBundle, FancyProj, circuit::Circuit, circuits::arithmetic::ModChange};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Given a [`CrtBundle`] `x` and modulus `p`, compute the remainder with respect to `p`.
///
/// The circuit takes `(x, p)` as input and returns a new [`CrtBundle`] over the same
/// moduli as `x` (in the same order) that encodes `x mod p`. It does this by taking the
/// wire of `x` whose modulus is `p` and converting it to every modulus of `x` with
/// [`ModChange`].
///
/// # Panics
/// Panics if `p` is not a modulus contained in `x`.
#[derive(Debug, Default, Clone, Copy)]
pub struct Remainder<'a>(PhantomData<&'a ()>);
12
13impl<'a> Remainder<'a> {
14    /// Create a new [`Remainder`] circuit.
15    pub fn new() -> Self {
16        Default::default()
17    }
18}
19
20impl<'a, F: FancyProj> Circuit<F> for Remainder<'a>
21where
22    F::Item: 'a,
23{
24    type Input = (&'a CrtBundle<F::Item>, u16);
25    type Output = CrtBundle<F::Item>;
26
27    fn execute(
28        &self,
29        backend: &mut F,
30        inputs: Self::Input,
31        channel: &mut Channel,
32    ) -> Result<Self::Output> {
33        let (x, modulus) = inputs;
34        let i = x.moduli().iter().position(|&q| modulus == q);
35        assert!(
36            i.is_some(),
37            "`modulus` {modulus} is not in the input bundle",
38        );
39        let i = i.unwrap();
40        let w = &x.wires()[i];
41
42        // Convert the wire modulo `modulus` to all the other moduli in the bundle.
43        x.moduli()
44            .iter()
45            .map(|&q| ModChange.execute(backend, (w.clone(), q), channel))
46            .collect::<Result<_>>()
47            .map(CrtBundle::new)
48    }
49}
50
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::Remainder,
        dummy::{Dummy, DummyVal},
        util::{RngExt, factor},
    };
    use rand::{Rng, thread_rng};

    /// Reducing a random CRT-encoded value modulo one of its own CRT moduli must agree
    /// with plain integer `%`.
    #[test]
    fn remainder() {
        const TRIALS: usize = 16;

        let mut rng = thread_rng();
        let composite = rng.gen_usable_composite_modulus();
        let moduli = factor(composite);

        for _trial in 0..TRIALS {
            let value = rng.r#gen::<u64>() as u128 % composite;
            let pick = rng.gen_range(0..moduli.len());
            let p = moduli[pick];

            let encoded = DummyVal::to_crt(value, composite);
            let circuit = Remainder::new();
            let reduced = Dummy::eval(&circuit, (&encoded, p)).unwrap();
            let decoded = DummyVal::from_crt(&reduced, composite);

            let expected = value % u128::from(p);
            assert_eq!(decoded, expected);
        }
    }
}