fancy_garbling/circuits/arithmetic/
mixed_radix.rs1use crate::{
2 CrtBundle, FancyArithmetic, FancyProj, HasModulus,
3 circuit::Circuit,
4 circuits::arithmetic::{ModChange, addition::AddMany},
5 util::{as_mixed_radix, inv, product},
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
/// Circuit that sums mixed-radix numbers but only materializes the most
/// significant digit of the result; the lower digits are used solely to
/// propagate carries upward.
#[derive(Default)]
struct MixedRadixAdditionMSBOnly<'a>(PhantomData<&'a ()>);

impl<'a> MixedRadixAdditionMSBOnly<'a> {
    /// Creates the (stateless) circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
20
21impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for MixedRadixAdditionMSBOnly<'a>
22where
23 F::Item: 'a,
24{
25 type Input = &'a [CrtBundle<F::Item>];
26 type Output = F::Item;
27
28 fn execute(
29 &self,
30 backend: &mut F,
31 inputs: Self::Input,
32 channel: &mut Channel,
33 ) -> Result<Self::Output> {
34 let xs = inputs;
35 assert!(!xs.is_empty(), "`inputs` cannot be empty");
36 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
37
38 let nargs = xs.len();
39 let n = xs[0].wires().len();
40
41 let mut opt_carry = None;
42 let mut max_carry = 0;
43
44 for i in 0..n - 1 {
45 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect::<Vec<_>>();
47 let q = xs[0].moduli()[i];
49 let max_val = nargs as u16 * (q - 1) + max_carry;
51 max_carry = max_val / q;
53
54 let modded_ds = ds
57 .into_iter()
58 .map(|d| ModChange.execute(backend, (d, max_val + 1), channel))
59 .collect::<swanky_error::Result<Vec<_>>>()?;
60 let sum = AddMany::new().execute(backend, modded_ds.as_slice(), channel)?;
62 let sum_with_carry = opt_carry
64 .as_ref()
65 .map_or(sum.clone(), |c| backend.add(&sum, c));
66
67 let next_mod = if i < n - 2 {
72 nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
73 } else {
74 inputs[0].moduli()[i + 1] };
76
77 let tt = (0..=max_val)
78 .map(|i| (i / q) % next_mod)
79 .collect::<Vec<_>>();
80 opt_carry = Some(backend.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
81 }
82
83 let ds = xs
85 .iter()
86 .map(|x| x.wires()[n - 1].clone())
87 .collect::<Vec<_>>();
88 let digit_sum = AddMany::new().execute(backend, ds.as_slice(), channel)?;
89 Ok(opt_carry
90 .as_ref()
91 .map_or(digit_sum.clone(), |d| backend.add(&digit_sum, d)))
92 }
93}
94
/// Circuit that adds a slice of mixed-radix numbers and returns the full sum,
/// reduced modulo the product of the moduli.
///
/// Each input is a `CrtBundle` whose wires are the digits of one number,
/// least significant first; all inputs must share the same moduli.
#[derive(Debug, Clone, Copy, Default)]
pub struct MixedRadixAddition<'a>(PhantomData<&'a ()>);

impl<'a> MixedRadixAddition<'a> {
    /// Creates the (stateless) circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
105
106impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for MixedRadixAddition<'a>
107where
108 F::Item: 'a,
109{
110 type Input = &'a [CrtBundle<F::Item>];
111 type Output = CrtBundle<F::Item>;
112
113 fn execute(
114 &self,
115 backend: &mut F,
116 inputs: Self::Input,
117 channel: &mut Channel,
118 ) -> Result<Self::Output> {
119 let xs = inputs;
120 assert!(!xs.is_empty(), "`xs` cannot be empty");
121 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
122
123 let nargs = xs.len();
124 let n = xs[0].wires().len();
125
126 let mut digit_carry = None;
127 let mut carry_carry = None;
128 let mut max_carry = 0;
129
130 let mut res = Vec::with_capacity(n);
131
132 for i in 0..n {
133 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect::<Vec<_>>();
135
136 let digit_sum = AddMany::new().execute(backend, ds.as_slice(), channel)?;
138 let digit = digit_carry.map_or(digit_sum.clone(), |d| backend.add(&digit_sum, &d));
139
140 if i < n - 1 {
141 let q = xs[0].wires()[i].modulus();
143 let max_val = nargs as u16 * (q - 1) + max_carry;
145 max_carry = max_val / q;
147
148 let modded_ds = ds
149 .into_iter()
150 .map(|d| ModChange.execute(backend, (d, max_val + 1), channel))
151 .collect::<Result<Vec<_>>>()?;
152
153 let carry_sum = AddMany::new().execute(backend, modded_ds.as_slice(), channel)?;
154 let carry = carry_carry.map_or(carry_sum.clone(), |c| backend.add(&carry_sum, &c));
156
157 let next_mod = xs[0].wires()[i + 1].modulus();
160 let tt = (0..=max_val)
161 .map(|i| (i / q) % next_mod)
162 .collect::<Vec<_>>();
163 digit_carry = Some(backend.proj(&carry, next_mod, Some(tt), channel)?);
164
165 let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
166
167 if i < n - 2 {
168 if max_carry < next_mod {
169 carry_carry = Some(ModChange.execute(
170 backend,
171 (digit_carry.as_ref().unwrap().clone(), next_max_val + 1),
172 channel,
173 )?);
174 } else {
175 let tt = (0..=max_val).map(|i| i / q).collect::<Vec<_>>();
176 carry_carry =
177 Some(backend.proj(&carry, next_max_val + 1, Some(tt), channel)?);
178 }
179 } else {
180 carry_carry = None;
182 }
183 } else {
184 digit_carry = None;
185 carry_carry = None;
186 }
187 res.push(digit);
188 }
189 Ok(CrtBundle::new(res))
190 }
191}
192
/// Circuit that maps each CRT residue of a bundle to a mixed-radix encoding
/// of its fractional CRT contribution, adds those encodings, and returns the
/// most significant digit of the sum.
#[derive(Debug, Clone, Copy, Default)]
pub struct FractionalMixedRadix<'a>(PhantomData<&'a ()>);

impl<'a> FractionalMixedRadix<'a> {
    /// Creates the (stateless) circuit.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
204
205impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for FractionalMixedRadix<'a>
206where
207 F::Item: 'a,
208{
209 type Input = (&'a CrtBundle<F::Item>, &'a [u16]);
210 type Output = F::Item;
211
212 fn execute(
213 &self,
214 backend: &mut F,
215 inputs: Self::Input,
216 channel: &mut Channel,
217 ) -> Result<Self::Output> {
218 let (bun, ms) = inputs;
219
220 let ndigits = ms.len();
221
222 let q = product(&bun.moduli());
223 let M = product(ms);
224
225 let mut ds = Vec::new();
226
227 for wire in bun.wires().iter() {
228 let p = wire.modulus();
229
230 let mut tabs = vec![Vec::with_capacity(p as usize); ndigits];
231
232 for x in 0..p {
233 let crt_coef = inv(((q / p as u128) % p as u128) as i128, p as i128);
234 let y = (M as f64 * x as f64 * crt_coef as f64 / p as f64).round() as u128 % M;
235 let digits = as_mixed_radix(y, ms);
236 for i in 0..ndigits {
237 tabs[i].push(digits[i]);
238 }
239 }
240
241 let new_ds = tabs
242 .into_iter()
243 .enumerate()
244 .map(|(i, tt)| backend.proj(wire, ms[i], Some(tt), channel))
245 .collect::<Result<Vec<_>>>()?;
246
247 ds.push(CrtBundle::new(new_ds));
248 }
249
250 MixedRadixAdditionMSBOnly::new().execute(backend, ds.as_slice(), channel)
251 }
252}
253
#[cfg(test)]
mod test {
    use rand::{Rng, thread_rng};

    use crate::{
        circuits::arithmetic::mixed_radix::{MixedRadixAddition, MixedRadixAdditionMSBOnly},
        dummy::{Dummy, DummyVal},
        util::{RngExt, as_mixed_radix, product},
    };

    /// The MSB-only adder must agree with the top digit of the true sum mod `q`.
    #[test]
    fn mixed_radix_addition_msb_only() {
        let mut rng = thread_rng();
        let nargs = 2 + rng.r#gen::<usize>() % 10;
        let moduli = (0..7).map(|_| rng.gen_modulus()).collect::<Vec<_>>();
        let q = product(&moduli);

        // Worst case first: every summand is q - 1, maximising every carry.
        let saturated: Vec<_> =
            std::iter::repeat_with(|| DummyVal::to_mixed_radix(q - 1, &moduli))
                .take(nargs)
                .collect();
        let out = Dummy::eval(&MixedRadixAdditionMSBOnly::new(), saturated.as_slice()).unwrap();
        let want = as_mixed_radix((q - 1) * (nargs as u128) % q, &moduli);
        assert_eq!(out.val(), *want.last().unwrap());

        // Then a handful of random cases.
        for _ in 0..4 {
            let values: Vec<u128> = (0..nargs).map(|_| rng.gen_u128() % q).collect();
            let total = values.iter().fold(0, |acc, &v| (acc + v) % q);
            let bundles: Vec<_> = values
                .iter()
                .map(|&v| DummyVal::to_mixed_radix(v, &moduli))
                .collect();
            let out = Dummy::eval(&MixedRadixAdditionMSBOnly::new(), bundles.as_slice()).unwrap();
            let want = as_mixed_radix(total, &moduli);
            assert_eq!(out.val(), *want.last().unwrap());
        }
    }

    /// The full adder must reproduce the complete sum mod `q`.
    #[test]
    fn test_mixed_radix_addition() {
        let mut rng = thread_rng();
        let nargs = 2 + rng.gen_usize() % 100;
        let moduli = (0..7).map(|_| rng.gen_modulus()).collect::<Vec<_>>();
        let q: u128 = moduli.iter().map(|&m| m as u128).product();

        // Worst case first: every summand is q - 1.
        let saturated: Vec<_> =
            std::iter::repeat_with(|| DummyVal::to_mixed_radix(q - 1, &moduli))
                .take(nargs)
                .collect();
        let out = Dummy::eval(&MixedRadixAddition::new(), saturated.as_slice()).unwrap();
        assert_eq!(
            DummyVal::from_mixed_radix(&out),
            (q - 1) * (nargs as u128) % q
        );

        // Then a handful of random cases.
        for _ in 0..4 {
            let values: Vec<u128> = (0..nargs).map(|_| rng.gen_u128() % q).collect();
            let total = values.iter().fold(0, |acc, &v| (acc + v) % q);
            let bundles: Vec<_> = values
                .iter()
                .map(|&v| DummyVal::to_mixed_radix(v, &moduli))
                .collect();
            let out = Dummy::eval(&MixedRadixAddition::new(), bundles.as_slice()).unwrap();
            assert_eq!(DummyVal::from_mixed_radix(&out), total);
        }
    }
}