// fancy_garbling/circuits/arithmetic/multiplication.rs

1use crate::{CrtBundle, FancyArithmetic, circuit::Circuit, util::crt};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Given [`CrtBundle`]s `x` and `y`, output `x * y`.
///
/// This type carries no data. The lifetime `'a` only ties the circuit to the
/// lifetime of the borrowed input bundles it is evaluated on.
#[derive(Debug, Default, Clone, Copy)]
pub struct Multiplication<'a>(PhantomData<&'a ()>);

impl<'a> Multiplication<'a> {
    /// Create a new [`Multiplication`] circuit.
    pub fn new() -> Self {
        Self::default()
    }
}
16
17impl<'a, F: FancyArithmetic> Circuit<F> for Multiplication<'a>
18where
19    F::Item: 'a,
20{
21    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
22    type Output = CrtBundle<F::Item>;
23
24    fn execute(
25        &self,
26        backend: &mut F,
27        inputs: Self::Input,
28        channel: &mut Channel,
29    ) -> Result<Self::Output> {
30        let (x, y) = inputs;
31        assert_eq!(x.size(), y.size());
32        let bundle = x
33            .wires()
34            .iter()
35            .zip(y.wires().iter())
36            .map(|(x, y)| backend.mul(x, y, channel))
37            .collect::<Result<Vec<_>>>()?;
38        Ok(CrtBundle::new(bundle))
39    }
40}
41
/// Given [`CrtBundle`] `x` and constant `c`, output `x * c`.
///
/// This type carries no data. The lifetime `'a` only ties the circuit to the
/// lifetime of the borrowed input bundle it is evaluated on.
#[derive(Debug, Default, Clone, Copy)]
pub struct ConstantMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> ConstantMultiplication<'a> {
    /// Create a new [`ConstantMultiplication`] circuit.
    pub fn new() -> Self {
        Self::default()
    }
}
52
53impl<'a, F: FancyArithmetic> Circuit<F> for ConstantMultiplication<'a>
54where
55    F::Item: 'a,
56{
57    type Input = (&'a CrtBundle<F::Item>, u128);
58    type Output = CrtBundle<F::Item>;
59
60    fn execute(
61        &self,
62        backend: &mut F,
63        inputs: Self::Input,
64        _: &mut Channel,
65    ) -> Result<Self::Output> {
66        let (x, c) = inputs;
67        let cs = crt(c, &x.moduli());
68        Ok(CrtBundle::new(
69            x.wires()
70                .iter()
71                .zip(cs)
72                .map(|(x, c)| backend.cmul(x, c))
73                .collect::<Vec<_>>(),
74        ))
75    }
76}
77
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::{ConstantMultiplication, Multiplication},
        dummy::{Dummy, DummyVal},
        util::RngExt,
    };
    use rand::{Rng, thread_rng};

    /// Evaluates [`Multiplication`] on `x` and `y` (both expected to be `< q`)
    /// with the dummy backend and decodes the result modulo `q`.
    fn eval_mul(x: u128, y: u128, q: u128) -> u128 {
        let x_input = DummyVal::to_crt(x, q);
        let y_input = DummyVal::to_crt(y, q);
        let circuit = Multiplication::new();
        let z = Dummy::eval(&circuit, (&x_input, &y_input)).unwrap();
        DummyVal::from_crt(&z, q)
    }

    /// Evaluates [`ConstantMultiplication`] on `x` (expected to be `< q`) and the
    /// raw constant `c` with the dummy backend, and decodes the result modulo `q`.
    fn eval_cmul(x: u128, c: u128, q: u128) -> u128 {
        let x_input = DummyVal::to_crt(x, q);
        let circuit = ConstantMultiplication::new();
        let z = Dummy::eval(&circuit, (&x_input, c)).unwrap();
        DummyVal::from_crt(&z, q)
    }

    #[test]
    fn multiplication() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        for _ in 0..16 {
            let x = rng.r#gen::<u64>() as u128 % q;
            let y = rng.r#gen::<u64>() as u128 % q;
            // Both operands are < 2^64, so `x * y` cannot overflow a u128.
            assert_eq!(eval_mul(x, y, q), (x * y) % q);
        }
    }

    #[test]
    fn multiplication_by_zero_and_one() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();
        let x = rng.r#gen::<u64>() as u128 % q;

        assert_eq!(eval_mul(x, 0, q), 0);
        assert_eq!(eval_mul(0, x, q), 0);
        assert_eq!(eval_mul(x, 1, q), x);
    }

    #[test]
    fn constant_multiplication() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        for _ in 0..16 {
            let x = rng.r#gen::<u64>() as u128 % q;
            let c = rng.r#gen::<u64>() as u128 % q;
            // Both operands are < 2^64, so `x * c` cannot overflow a u128.
            assert_eq!(eval_cmul(x, c, q), (x * c) % q);
        }
    }

    #[test]
    fn constant_multiplication_edge_constants() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();
        let x = rng.r#gen::<u64>() as u128 % q;

        assert_eq!(eval_cmul(x, 0, q), 0);
        assert_eq!(eval_cmul(x, 1, q), x);
        // A constant >= q must be reduced by the CRT decomposition; q itself is 0 mod q.
        assert_eq!(eval_cmul(x, q, q), 0);
    }
}