fancy_garbling/circuits/arithmetic/
addition.rs1use crate::{CrtBundle, FancyArithmetic, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
6#[derive(Default)]
8pub struct Addition<'a>(PhantomData<&'a ()>);
9
10impl<'a> Addition<'a> {
11 pub fn new() -> Self {
13 Default::default()
14 }
15}
16
17impl<'a, F: FancyArithmetic> Circuit<F> for Addition<'a>
18where
19 F::Item: 'a,
20{
21 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
22 type Output = CrtBundle<F::Item>;
23
24 fn execute(
25 &self,
26 backend: &mut F,
27 inputs: Self::Input,
28 _: &mut Channel,
29 ) -> Result<Self::Output> {
30 let (x, y) = inputs;
31 assert_eq!(x.size(), y.size(), "`x` and `y` must be the same length");
32 Ok(CrtBundle::new(
33 x.wires()
34 .iter()
35 .zip(y.wires().iter())
36 .map(|(x, y)| backend.add(x, y))
37 .collect(),
38 ))
39 }
40}
41
/// Circuit that sums a slice of two or more wires into a single wire.
///
/// This is a zero-sized marker type. The lifetime `'a` exists only so that its
/// `Circuit` implementation can take the input slice by reference.
#[derive(Debug, Default, Clone, Copy)]
pub struct AddMany<'a>(PhantomData<&'a ()>);

impl AddMany<'_> {
    /// Creates a new many-input addition circuit.
    ///
    /// Equivalent to `AddMany::default()`, but usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self(PhantomData)
    }
}
52
53impl<'a, F: FancyArithmetic> Circuit<F> for AddMany<'a>
54where
55 F::Item: 'a,
56{
57 type Input = &'a [F::Item];
58 type Output = F::Item;
59
60 fn execute(
61 &self,
62 backend: &mut F,
63 inputs: Self::Input,
64 _: &mut Channel,
65 ) -> Result<Self::Output> {
66 assert!(inputs.len() >= 2, "`args.len()` must be two or more");
67 let mut z = inputs[0].clone();
68 for x in inputs.iter().skip(1) {
69 z = backend.add(&z, x);
70 }
71 Ok(z)
72 }
73}
74
75#[cfg(test)]
76mod test {
77 use crate::{
78 circuits::arithmetic::Addition,
79 dummy::{Dummy, DummyVal},
80 util::RngExt,
81 };
82 use rand::{Rng, thread_rng};
83
84 #[test]
85 fn addition() {
86 let mut rng = thread_rng();
87 let q = rng.gen_usable_composite_modulus();
88
89 for _ in 0..16 {
90 let x = rng.r#gen::<u128>() % q;
91 let y = rng.r#gen::<u128>() % q;
92 let x_input = DummyVal::to_crt(x, q);
93 let y_input = DummyVal::to_crt(y, q);
94 let z = Dummy::eval(&Addition::new(), (&x_input, &y_input)).unwrap();
95 let output = DummyVal::from_crt(&z, q);
96 assert_eq!(output, (x + y) % q);
97 }
98 }
99}