Skip to main content

fancy_garbling/circuits/arithmetic/addition.rs

1use crate::{CrtBundle, FancyArithmetic, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Circuit computing the sum `x + y` of two [`CrtBundle`]s `x` and `y`.
///
/// The struct carries no data; the lifetime `'a` exists only so the circuit's
/// input type can borrow the two bundles being added.
#[derive(Default)]
pub struct Addition<'a>(PhantomData<&'a ()>);

impl<'a> Addition<'a> {
    /// Construct an [`Addition`] circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
16
17impl<'a, F: FancyArithmetic> Circuit<F> for Addition<'a>
18where
19    F::Item: 'a,
20{
21    type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
22    type Output = CrtBundle<F::Item>;
23
24    fn execute(
25        &self,
26        backend: &mut F,
27        inputs: Self::Input,
28        _: &mut Channel,
29    ) -> Result<Self::Output> {
30        let (x, y) = inputs;
31        assert_eq!(x.size(), y.size(), "`x` and `y` must be the same length");
32        Ok(CrtBundle::new(
33            x.wires()
34                .iter()
35                .zip(y.wires().iter())
36                .map(|(x, y)| backend.add(x, y))
37                .collect(),
38        ))
39    }
40}
41
/// Circuit computing the sum of a slice of wires `x`, i.e. `x[0] + x[1] + …`.
///
/// The struct carries no data; the lifetime `'a` exists only so the circuit's
/// input type can borrow the slice of wires being summed.
#[derive(Default)]
pub struct AddMany<'a>(PhantomData<&'a ()>);

impl<'a> AddMany<'a> {
    /// Construct an [`AddMany`] circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
52
53impl<'a, F: FancyArithmetic> Circuit<F> for AddMany<'a>
54where
55    F::Item: 'a,
56{
57    type Input = &'a [F::Item];
58    type Output = F::Item;
59
60    fn execute(
61        &self,
62        backend: &mut F,
63        inputs: Self::Input,
64        _: &mut Channel,
65    ) -> Result<Self::Output> {
66        assert!(inputs.len() >= 2, "`args.len()` must be two or more");
67        let mut z = inputs[0].clone();
68        for x in inputs.iter().skip(1) {
69            z = backend.add(&z, x);
70        }
71        Ok(z)
72    }
73}
74
#[cfg(test)]
mod test {
    use crate::{
        circuits::arithmetic::Addition,
        dummy::{Dummy, DummyVal},
        util::RngExt,
    };
    use rand::{Rng, thread_rng};

    /// Evaluates `Addition` on the dummy backend for random operands and
    /// compares the result against plain addition modulo a random composite
    /// CRT modulus.
    #[test]
    fn addition() {
        let mut rng = thread_rng();
        let modulus = rng.gen_usable_composite_modulus();

        for _ in 0..16 {
            let a = rng.r#gen::<u128>() % modulus;
            let b = rng.r#gen::<u128>() % modulus;
            let expected = (a + b) % modulus;

            let a_crt = DummyVal::to_crt(a, modulus);
            let b_crt = DummyVal::to_crt(b, modulus);
            let sum = Dummy::eval(&Addition::new(), (&a_crt, &b_crt)).unwrap();

            assert_eq!(DummyVal::from_crt(&sum, modulus), expected);
        }
    }
}