Skip to main content

fancy_garbling/circuits/arithmetic/division.rs

1use crate::{
2    CrtBundle, CrtGadgets, FancyArithmetic, FancyBinary, FancyProj,
3    circuit::Circuit,
4    circuits::arithmetic::{
5        Addition, Constant, ConstantMultiplication, Multiplication, PmrGreaterThanOrEqual,
6        Subtraction,
7    },
8    util::product,
9};
10use core::marker::PhantomData;
11use swanky_channel::Channel;
12use swanky_error::Result;
13
/// For [`CrtBundle`]s `x` and `y`, output `x / y` (integer division, rounded down).
///
/// The inputs are required to have an extra (unused) prime. That is, for
/// modulus $`Q = \prod_{i = 1,...,n} q_i`$, plaintext inputs `x` and `y` must
/// be less than $`Q / q_n`$.
///
/// The divisor `y` should be nonzero. The circuit has no special case for
/// `y = 0`: every quotient bit is set, so it outputs $`2^l - 1`$, where `l` is
/// the bit length of $`Q / q_n`$.
#[derive(Debug, Default, Clone, Copy)]
pub struct Division<'a>(PhantomData<&'a ()>);

impl<'a> Division<'a> {
    /// Create a new [`Division`] circuit.
    ///
    /// The circuit carries no state, so this is a `const fn` and the value is `Copy`.
    pub const fn new() -> Self {
        Self(PhantomData)
    }
}
28
29impl<'a, F: FancyBinary + FancyArithmetic + FancyProj + CrtGadgets> Circuit<F> for Division<'a>
30where
31    F::Item: 'a,
32{
33    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
34    type Output = CrtBundle<F::Item>;
35
36    fn execute(
37        &self,
38        backend: &mut F,
39        inputs: Self::Input,
40        channel: &mut Channel,
41    ) -> Result<Self::Output> {
42        let (x, y) = inputs;
43        assert_eq!(x.moduli(), y.moduli());
44
45        let q = x.composite_modulus();
46
47        // Compute l based on the assumption that the last prime is unused.
48        let nprimes = x.moduli().len();
49        let qs_ = &x.moduli()[..nprimes - 1];
50        let q_ = product(qs_);
51        let l = 128 - q_.leading_zeros();
52
53        let mut quotient = Constant::new(0, q).execute(backend, (), channel)?;
54        let mut a = (*x).clone();
55
56        let one = Constant::new(1, q).execute(backend, (), channel)?;
57        for i in 0..l {
58            let b = 2u128.pow(l - i - 1);
59            let mut pb = q_ / b;
60            if q_.is_multiple_of(b) {
61                pb -= 1;
62            }
63
64            let tmp = ConstantMultiplication::new().execute(backend, (y, b), channel)?;
65            let c1 = PmrGreaterThanOrEqual::new().execute(backend, (&a, &tmp), channel)?;
66
67            let pb_crt = Constant::new(pb, q).execute(backend, (), channel)?;
68            let c2 = PmrGreaterThanOrEqual::new().execute(backend, (&pb_crt, y), channel)?;
69
70            let c = backend.and(&c1, &c2, channel)?;
71
72            let c_ws = one
73                .iter()
74                .map(|w| backend.mul(w, &c, channel))
75                .collect::<Result<Vec<_>>>()?;
76            let c_crt = CrtBundle::new(c_ws);
77
78            let b_if = ConstantMultiplication::new().execute(backend, (&c_crt, b), channel)?;
79            quotient = Addition::new().execute(backend, (&quotient, &b_if), channel)?;
80
81            let tmp_if = Multiplication::new().execute(backend, (&c_crt, &tmp), channel)?;
82            a = Subtraction::new().execute(backend, (&a, &tmp_if), channel)?;
83        }
84
85        Ok(quotient)
86    }
87}
88
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::Division,
        dummy::{Dummy, DummyVal},
        util::{RngExt, product},
    };
    use rand::{Rng, thread_rng};

    #[test]
    fn division() {
        let mut rng = thread_rng();

        for _ in 0..2 {
            let qs = rng.gen_usable_factors();
            let q = product(&qs);
            // The last prime is unused headroom; plaintexts must stay below `q_`.
            let q_ = product(&qs[..qs.len() - 1]);
            // With `q_ == 1` the only admissible plaintext is zero, so there is
            // no valid (nonzero) divisor to exercise.
            if q_ < 2 {
                continue;
            }

            // Evaluates the division circuit on plaintexts `x` and `y`.
            let divide = |x: u128, y: u128| {
                let x_input = DummyVal::to_crt(x, q);
                let y_input = DummyVal::to_crt(y, q);
                let z = Dummy::eval(&Division::new(), (&x_input, &y_input)).unwrap();
                DummyVal::from_crt(&z, q)
            };

            let x = rng.r#gen::<u128>() % q_;
            // Draw the divisor from `1..q_` so the reference `x / y` never
            // divides by zero.
            let y = 1 + rng.r#gen::<u128>() % (q_ - 1);
            assert_eq!(divide(x, y), x / y);

            // Boundary cases: largest dividend, unit and largest divisors,
            // zero dividend, and a dividend just below the divisor.
            let max = q_ - 1;
            assert_eq!(divide(max, 1), max);
            assert_eq!(divide(max, max), 1);
            assert_eq!(divide(0, y), 0);
            assert_eq!(divide(y - 1, y), 0);
        }
    }

    /// A zero divisor is not special-cased by the circuit: every quotient bit
    /// is set, giving `2^l - 1` where `l` is the bit length of the plaintext bound.
    #[test]
    fn division_by_zero_sets_all_bits() {
        let mut rng = thread_rng();

        let qs = rng.gen_usable_factors();
        let q = product(&qs);
        let q_ = product(&qs[..qs.len() - 1]);
        let nbits = u128::BITS - q_.leading_zeros();

        let x = rng.r#gen::<u128>() % q_;
        let x_input = DummyVal::to_crt(x, q);
        let y_input = DummyVal::to_crt(0, q);
        let z = Dummy::eval(&Division::new(), (&x_input, &y_input)).unwrap();
        let output = DummyVal::from_crt(&z, q);
        assert_eq!(output, u128::MAX >> (u128::BITS - nbits));
    }
}