Skip to main content

fancy_garbling/circuits/arithmetic/
mixed_radix.rs

1use crate::{
2    CrtBundle, FancyArithmetic, FancyProj, HasModulus,
3    circuit::Circuit,
4    circuits::arithmetic::{ModChange, addition::AddMany},
5    util::{as_mixed_radix, inv, product},
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
/// Internal circuit that adds several mixed-radix [`CrtBundle`]s together but only
/// materializes the most significant digit of the sum; the lower digits are used solely
/// to propagate carries.
///
/// Used by [`FractionalMixedRadix`] to extract the leading fractional digit.
#[derive(Default)]
struct MixedRadixAdditionMSBOnly<'a>(PhantomData<&'a ()>);

impl<'a> MixedRadixAdditionMSBOnly<'a> {
    /// Create a new [`MixedRadixAdditionMSBOnly`] circuit.
    ///
    /// The circuit carries no state; the lifetime parameter only ties the circuit's
    /// input slice type to the borrowed bundles.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
20
21impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for MixedRadixAdditionMSBOnly<'a>
22where
23    F::Item: 'a,
24{
25    type Input = &'a [CrtBundle<F::Item>];
26    type Output = F::Item;
27
28    fn execute(
29        &self,
30        backend: &mut F,
31        inputs: Self::Input,
32        channel: &mut Channel,
33    ) -> Result<Self::Output> {
34        let xs = inputs;
35        assert!(!xs.is_empty(), "`inputs` cannot be empty");
36        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
37
38        let nargs = xs.len();
39        let n = xs[0].wires().len();
40
41        let mut opt_carry = None;
42        let mut max_carry = 0;
43
44        for i in 0..n - 1 {
45            // all the ith digits, in one vec
46            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect::<Vec<_>>();
47            // compute the carry
48            let q = xs[0].moduli()[i];
49            // max_carry currently contains the max carry from the previous iteration
50            let max_val = nargs as u16 * (q - 1) + max_carry;
51            // now it is the max carry of this iteration
52            max_carry = max_val / q;
53
54            // mod change the digits to the max sum possible plus the max carry of the
55            // previous iteration
56            let modded_ds = ds
57                .into_iter()
58                .map(|d| ModChange.execute(backend, (d, max_val + 1), channel))
59                .collect::<swanky_error::Result<Vec<_>>>()?;
60            // add them up
61            let sum = AddMany::new().execute(backend, modded_ds.as_slice(), channel)?;
62            // add in the carry
63            let sum_with_carry = opt_carry
64                .as_ref()
65                .map_or(sum.clone(), |c| backend.add(&sum, c));
66
67            // carry now contains the carry information, we just have to project it to
68            // the correct moduli for the next iteration. It will either be used to
69            // compute the next carry, if i < n-2, or it will be used to compute the
70            // output MSB, in which case it should be the modulus of the SB
71            let next_mod = if i < n - 2 {
72                nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
73            } else {
74                inputs[0].moduli()[i + 1] // we will be adding the carry to the MSB
75            };
76
77            let tt = (0..=max_val)
78                .map(|i| (i / q) % next_mod)
79                .collect::<Vec<_>>();
80            opt_carry = Some(backend.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
81        }
82
83        // compute the msb
84        let ds = xs
85            .iter()
86            .map(|x| x.wires()[n - 1].clone())
87            .collect::<Vec<_>>();
88        let digit_sum = AddMany::new().execute(backend, ds.as_slice(), channel)?;
89        Ok(opt_carry
90            .as_ref()
91            .map_or(digit_sum.clone(), |d| backend.add(&digit_sum, d)))
92    }
93}
94
/// Mixed radix addition.
///
/// Sums a non-empty slice of [`CrtBundle`]s whose wires are mixed-radix digits (least
/// significant digit first) over identical moduli, producing a bundle with the same
/// moduli that holds the sum modulo the product of those moduli.
#[derive(Default)]
pub struct MixedRadixAddition<'a>(PhantomData<&'a ()>);

impl<'a> MixedRadixAddition<'a> {
    /// Create a new [`MixedRadixAddition`] circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
105
106impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for MixedRadixAddition<'a>
107where
108    F::Item: 'a,
109{
110    type Input = &'a [CrtBundle<F::Item>];
111    type Output = CrtBundle<F::Item>;
112
113    fn execute(
114        &self,
115        backend: &mut F,
116        inputs: Self::Input,
117        channel: &mut Channel,
118    ) -> Result<Self::Output> {
119        let xs = inputs;
120        assert!(!xs.is_empty(), "`xs` cannot be empty");
121        assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
122
123        let nargs = xs.len();
124        let n = xs[0].wires().len();
125
126        let mut digit_carry = None;
127        let mut carry_carry = None;
128        let mut max_carry = 0;
129
130        let mut res = Vec::with_capacity(n);
131
132        for i in 0..n {
133            // all the ith digits, in one vec
134            let ds = xs.iter().map(|x| x.wires()[i].clone()).collect::<Vec<_>>();
135
136            // compute the digit -- easy
137            let digit_sum = AddMany::new().execute(backend, ds.as_slice(), channel)?;
138            let digit = digit_carry.map_or(digit_sum.clone(), |d| backend.add(&digit_sum, &d));
139
140            if i < n - 1 {
141                // compute the carries
142                let q = xs[0].wires()[i].modulus();
143                // max_carry currently contains the max carry from the previous iteration
144                let max_val = nargs as u16 * (q - 1) + max_carry;
145                // now it is the max carry of this iteration
146                max_carry = max_val / q;
147
148                let modded_ds = ds
149                    .into_iter()
150                    .map(|d| ModChange.execute(backend, (d, max_val + 1), channel))
151                    .collect::<Result<Vec<_>>>()?;
152
153                let carry_sum = AddMany::new().execute(backend, modded_ds.as_slice(), channel)?;
154                // add in the carry from the previous iteration
155                let carry = carry_carry.map_or(carry_sum.clone(), |c| backend.add(&carry_sum, &c));
156
157                // carry now contains the carry information, we just have to project it to
158                // the correct moduli for the next iteration
159                let next_mod = xs[0].wires()[i + 1].modulus();
160                let tt = (0..=max_val)
161                    .map(|i| (i / q) % next_mod)
162                    .collect::<Vec<_>>();
163                digit_carry = Some(backend.proj(&carry, next_mod, Some(tt), channel)?);
164
165                let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
166
167                if i < n - 2 {
168                    if max_carry < next_mod {
169                        carry_carry = Some(ModChange.execute(
170                            backend,
171                            (digit_carry.as_ref().unwrap().clone(), next_max_val + 1),
172                            channel,
173                        )?);
174                    } else {
175                        let tt = (0..=max_val).map(|i| i / q).collect::<Vec<_>>();
176                        carry_carry =
177                            Some(backend.proj(&carry, next_max_val + 1, Some(tt), channel)?);
178                    }
179                } else {
180                    // next digit is MSB so we dont need carry_carry
181                    carry_carry = None;
182                }
183            } else {
184                digit_carry = None;
185                carry_carry = None;
186            }
187            res.push(digit);
188        }
189        Ok(CrtBundle::new(res))
190    }
191}
192
/// For an input [`CrtBundle`] `x` whose moduli multiply to `Q`, and a vector of moduli
/// `ms`, output the most significant mixed-radix digit (with respect to `ms`) of the
/// fractional part of `x / Q`.
///
/// Concretely, with `M = product(ms)`, each CRT residue's contribution to
/// `M * frac(x / Q)` is rounded to an integer and the rounded contributions are added
/// modulo `M`; the result is the top digit of that sum, i.e. an approximation of the
/// leading digit of `frac(x / Q)`.
#[derive(Default)]
pub struct FractionalMixedRadix<'a>(PhantomData<&'a ()>);

impl<'a> FractionalMixedRadix<'a> {
    /// Create a new [`FractionalMixedRadix`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
204
205impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for FractionalMixedRadix<'a>
206where
207    F::Item: 'a,
208{
209    type Input = (&'a CrtBundle<F::Item>, &'a [u16]);
210    type Output = F::Item;
211
212    fn execute(
213        &self,
214        backend: &mut F,
215        inputs: Self::Input,
216        channel: &mut Channel,
217    ) -> Result<Self::Output> {
218        let (bun, ms) = inputs;
219
220        let ndigits = ms.len();
221
222        let q = product(&bun.moduli());
223        let M = product(ms);
224
225        let mut ds = Vec::new();
226
227        for wire in bun.wires().iter() {
228            let p = wire.modulus();
229
230            let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
231
232            for x in 0..p {
233                let crt_coef = inv(((q / p as u128) % p as u128) as i128, p as i128);
234                let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
235                let digits = as_mixed_radix(y, ms);
236                for i in 0..ndigits {
237                    tabs[i].push(digits[i]);
238                }
239            }
240
241            let new_ds = tabs
242                .into_iter()
243                .enumerate()
244                .map(|(i, tt)| backend.proj(wire, ms[i], Some(tt), channel))
245                .collect::<Result<Vec<_>>>()?;
246
247            ds.push(CrtBundle::new(new_ds));
248        }
249
250        MixedRadixAdditionMSBOnly::new().execute(backend, ds.as_slice(), channel)
251    }
252}
253
#[cfg(test)]
mod test {
    use rand::{Rng, thread_rng};

    use crate::{
        circuits::arithmetic::mixed_radix::{MixedRadixAddition, MixedRadixAdditionMSBOnly},
        dummy::{Dummy, DummyVal},
        util::{RngExt, as_mixed_radix, product},
    };

    /// Checks [`MixedRadixAdditionMSBOnly`] against the MSB of the plain integer sum, for
    /// all-maximal inputs and for random inputs.
    #[test]
    fn mixed_radix_addition_msb_only() {
        let mut rng = thread_rng();
        let nargs = 2 + rng.r#gen::<usize>() % 10;
        let moduli: Vec<_> = (0..7).map(|_| rng.gen_modulus()).collect();
        let q = product(&moduli);
        // Most significant mixed-radix digit of `value`.
        let msb_of = |value: u128| *as_mixed_radix(value, &moduli).last().unwrap();

        // Every input is `q - 1`, so every digit position overflows.
        let saturated: Vec<_> =
            std::iter::repeat_with(|| DummyVal::to_mixed_radix(q - 1, &moduli))
                .take(nargs)
                .collect();
        let output = Dummy::eval(&MixedRadixAdditionMSBOnly::new(), saturated.as_slice()).unwrap();
        assert_eq!(output.val(), msb_of((q - 1) * (nargs as u128) % q));

        // Random inputs.
        for _ in 0..4 {
            let values: Vec<u128> = (0..nargs).map(|_| rng.gen_u128() % q).collect();
            let expected = values.iter().fold(0, |acc, &x| (acc + x) % q);
            let inputs: Vec<_> = values
                .iter()
                .map(|&x| DummyVal::to_mixed_radix(x, &moduli))
                .collect();
            let output = Dummy::eval(&MixedRadixAdditionMSBOnly::new(), inputs.as_slice()).unwrap();
            assert_eq!(output.val(), msb_of(expected));
        }
    }

    /// Checks [`MixedRadixAddition`] against the plain integer sum modulo the product of
    /// the moduli, for all-maximal inputs and for random inputs.
    #[test]
    fn test_mixed_radix_addition() {
        let mut rng = thread_rng();
        let nargs = 2 + rng.gen_usize() % 100;
        let moduli: Vec<_> = (0..7).map(|_| rng.gen_modulus()).collect();
        let q: u128 = moduli.iter().map(|&m| m as u128).product();

        // Every input is `q - 1`, so every digit position overflows.
        let saturated: Vec<_> =
            std::iter::repeat_with(|| DummyVal::to_mixed_radix(q - 1, &moduli))
                .take(nargs)
                .collect();
        let output = Dummy::eval(&MixedRadixAddition::new(), saturated.as_slice()).unwrap();
        assert_eq!(
            DummyVal::from_mixed_radix(&output),
            (q - 1) * (nargs as u128) % q
        );

        // Random inputs.
        for _ in 0..4 {
            let values: Vec<u128> = (0..nargs).map(|_| rng.gen_u128() % q).collect();
            let expected = values.iter().fold(0, |acc, &x| (acc + x) % q);
            let inputs: Vec<_> = values
                .iter()
                .map(|&x| DummyVal::to_mixed_radix(x, &moduli))
                .collect();
            let output = Dummy::eval(&MixedRadixAddition::new(), inputs.as_slice()).unwrap();
            assert_eq!(DummyVal::from_mixed_radix(&output), expected);
        }
    }
}