fancy_garbling/circuits/arithmetic/
multiplication.rs1use crate::{CrtBundle, FancyArithmetic, circuit::Circuit, util::crt};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Circuit that multiplies two CRT-encoded values wire-by-wire.
///
/// The struct carries no data; the lifetime `'a` only ties the circuit to the
/// borrowed [`CrtBundle`] inputs it accepts.
#[derive(Clone, Copy, Debug, Default)]
pub struct Multiplication<'a>(PhantomData<&'a ()>);

impl<'a> Multiplication<'a> {
    /// Creates a new multiplication circuit.
    #[must_use]
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
16
17impl<'a, F: FancyArithmetic> Circuit<F> for Multiplication<'a>
18where
19 F::Item: 'a,
20{
21 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
22 type Output = CrtBundle<F::Item>;
23
24 fn execute(
25 &self,
26 backend: &mut F,
27 inputs: Self::Input,
28 channel: &mut Channel,
29 ) -> Result<Self::Output> {
30 let (x, y) = inputs;
31 assert_eq!(x.size(), y.size());
32 let bundle = x
33 .wires()
34 .iter()
35 .zip(y.wires().iter())
36 .map(|(x, y)| backend.mul(x, y, channel))
37 .collect::<Result<Vec<_>>>()?;
38 Ok(CrtBundle::new(bundle))
39 }
40}
41
/// Circuit that multiplies a CRT-encoded value by a `u128` constant.
///
/// The struct carries no data; the lifetime `'a` only ties the circuit to the
/// borrowed [`CrtBundle`] input it accepts.
#[derive(Clone, Copy, Debug, Default)]
pub struct ConstantMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> ConstantMultiplication<'a> {
    /// Creates a new constant-multiplication circuit.
    #[must_use]
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
52
53impl<'a, F: FancyArithmetic> Circuit<F> for ConstantMultiplication<'a>
54where
55 F::Item: 'a,
56{
57 type Input = (&'a CrtBundle<F::Item>, u128);
58 type Output = CrtBundle<F::Item>;
59
60 fn execute(
61 &self,
62 backend: &mut F,
63 inputs: Self::Input,
64 _: &mut Channel,
65 ) -> Result<Self::Output> {
66 let (x, c) = inputs;
67 let cs = crt(c, &x.moduli());
68 Ok(CrtBundle::new(
69 x.wires()
70 .iter()
71 .zip(cs)
72 .map(|(x, c)| backend.cmul(x, c))
73 .collect::<Vec<_>>(),
74 ))
75 }
76}
77
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::{ConstantMultiplication, Multiplication},
        dummy::{Dummy, DummyVal},
        util::RngExt,
    };
    use rand::{Rng, thread_rng};

    /// Number of random trials each test runs.
    const TRIALS: usize = 16;

    /// Draws a random 64-bit value and reduces it into `[0, q)`.
    fn sample_below(rng: &mut impl Rng, q: u128) -> u128 {
        u128::from(rng.r#gen::<u64>()) % q
    }

    #[test]
    fn multiplication() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        for _ in 0..TRIALS {
            let lhs = sample_below(&mut rng, q);
            let rhs = sample_below(&mut rng, q);
            let lhs_bundle = DummyVal::to_crt(lhs, q);
            let rhs_bundle = DummyVal::to_crt(rhs, q);
            let circuit = Multiplication::new();
            let product = Dummy::eval(&circuit, (&lhs_bundle, &rhs_bundle)).unwrap();
            let decoded = DummyVal::from_crt(&product, q);
            // Both operands are below 2^64, so the u128 product cannot overflow.
            assert_eq!(decoded, (lhs * rhs) % q);
        }
    }

    #[test]
    fn constant_multiplication() {
        let mut rng = thread_rng();
        let q = rng.gen_usable_composite_modulus();

        for _ in 0..TRIALS {
            let value = sample_below(&mut rng, q);
            let constant = sample_below(&mut rng, q);
            let value_bundle = DummyVal::to_crt(value, q);
            let circuit = ConstantMultiplication::new();
            let product = Dummy::eval(&circuit, (&value_bundle, constant)).unwrap();
            let decoded = DummyVal::from_crt(&product, q);
            assert_eq!(decoded, (value * constant) % q);
        }
    }
}