1use crate::{
2 FancyArithmetic, FancyBinary, FancyProj,
3 fancy::{Fancy, HasModulus},
4};
5use itertools::Itertools;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::ops::Index;
9use swanky_channel::Channel;
10
/// An ordered collection of wires of type `W`, stored in a `Vec`.
///
/// The wrapped vector is private; the inherent methods on `Bundle` expose it
/// as a slice or iterator and provide the mutating operations.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Bundle<W>(Vec<W>);
15
16impl<W: Clone + HasModulus> Bundle<W> {
17 pub fn new(ws: Vec<W>) -> Bundle<W> {
19 Bundle(ws)
20 }
21
22 pub fn moduli(&self) -> Vec<u16> {
24 self.0.iter().map(HasModulus::modulus).collect()
25 }
26
27 pub fn wires(&self) -> &[W] {
29 &self.0
30 }
31
32 pub fn size(&self) -> usize {
34 self.0.len()
35 }
36
37 pub fn is_binary(&self) -> bool {
39 self.moduli().iter().all(|m| *m == 2)
40 }
41
42 pub fn with_moduli(&self, moduli: &[u16]) -> Bundle<W> {
44 let old_ws = self.wires();
45 let mut new_ws = Vec::with_capacity(moduli.len());
46 for &p in moduli {
47 if let Some(w) = old_ws.iter().find(|&x| x.modulus() == p) {
48 new_ws.push(w.clone());
49 } else {
50 panic!("Bundle::with_moduli: no {} modulus in bundle", p);
51 }
52 }
53 Bundle(new_ws)
54 }
55
56 pub fn pad(&mut self, val: &W, n: usize) {
58 for _ in 0..n {
59 self.0.push(val.clone());
60 }
61 }
62
63 pub fn extract(&mut self, wire_index: usize) -> W {
65 self.0.remove(wire_index)
66 }
67
68 pub fn insert(&mut self, wire_index: usize, val: W) {
70 self.0.insert(wire_index, val)
71 }
72
73 pub fn push(&mut self, val: W) {
75 self.0.push(val);
76 }
77
78 pub fn pop(&mut self) -> Option<W> {
80 self.0.pop()
81 }
82
83 pub fn iter(&self) -> std::slice::Iter<'_, W> {
85 self.0.iter()
86 }
87
88 pub fn reverse(&mut self) {
90 self.0.reverse();
91 }
92}
93
94impl<W: Clone + HasModulus> Index<usize> for Bundle<W> {
95 type Output = W;
96
97 fn index(&self, idx: usize) -> &Self::Output {
98 self.0.index(idx)
99 }
100}
101
102impl<F: Fancy> BundleGadgets for F {}
103impl<F: FancyArithmetic> ArithmeticBundleGadgets for F {}
104impl<F: FancyArithmetic + FancyProj> ArithmeticProjBundleGadgets for F {}
105impl<F: FancyBinary> BinaryBundleGadgets for F {}
106
107pub trait ArithmeticBundleGadgets: FancyArithmetic {
110 fn add_bundles(
117 &mut self,
118 x: &Bundle<Self::Item>,
119 y: &Bundle<Self::Item>,
120 ) -> Bundle<Self::Item> {
121 assert_eq!(
122 x.wires().len(),
123 y.wires().len(),
124 "`x` and `y` must be the same length"
125 );
126 Bundle::new(
127 x.wires()
128 .iter()
129 .zip(y.wires().iter())
130 .map(|(x, y)| self.add(x, y))
131 .collect::<Vec<Self::Item>>(),
132 )
133 }
134
135 fn sub_bundles(
142 &mut self,
143 x: &Bundle<Self::Item>,
144 y: &Bundle<Self::Item>,
145 ) -> Bundle<Self::Item> {
146 assert_eq!(
147 x.wires().len(),
148 y.wires().len(),
149 "`x` and `y` must be the same length"
150 );
151 Bundle::new(
152 x.wires()
153 .iter()
154 .zip(y.wires().iter())
155 .map(|(x, y)| self.sub(x, y))
156 .collect::<Vec<Self::Item>>(),
157 )
158 }
159
160 fn mul_bundles(
164 &mut self,
165 x: &Bundle<Self::Item>,
166 y: &Bundle<Self::Item>,
167 channel: &mut Channel,
168 ) -> swanky_error::Result<Bundle<Self::Item>> {
169 x.wires()
170 .iter()
171 .zip(y.wires().iter())
172 .map(|(x, y)| self.mul(x, y, channel))
173 .collect::<swanky_error::Result<_>>()
174 .map(Bundle::new)
175 }
176
177 fn mask(
179 &mut self,
180 b: &Self::Item,
181 x: &Bundle<Self::Item>,
182 channel: &mut Channel,
183 ) -> swanky_error::Result<Bundle<Self::Item>> {
184 x.wires()
185 .iter()
186 .map(|xwire| self.mul(xwire, b, channel))
187 .collect::<swanky_error::Result<_>>()
188 .map(Bundle)
189 }
190}
191
192pub trait ArithmeticProjBundleGadgets: FancyArithmetic + FancyProj {
194 fn mixed_radix_addition(
199 &mut self,
200 xs: &[Bundle<Self::Item>],
201 channel: &mut Channel,
202 ) -> swanky_error::Result<Bundle<Self::Item>> {
203 assert!(!xs.is_empty(), "`xs` cannot be empty");
204 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
205
206 let nargs = xs.len();
207 let n = xs[0].wires().len();
208
209 let mut digit_carry = None;
210 let mut carry_carry = None;
211 let mut max_carry = 0;
212
213 let mut res = Vec::with_capacity(n);
214
215 for i in 0..n {
216 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
218
219 let digit_sum = self.add_many(&ds);
221 let digit = digit_carry.map_or(digit_sum.clone(), |d| self.add(&digit_sum, &d));
222
223 if i < n - 1 {
224 let q = xs[0].wires()[i].modulus();
226 let max_val = nargs as u16 * (q - 1) + max_carry;
228 max_carry = max_val / q;
230
231 let modded_ds = ds
232 .iter()
233 .map(|d| self.mod_change(d, max_val + 1, channel))
234 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
235
236 let carry_sum = self.add_many(&modded_ds);
237 let carry = carry_carry.map_or(carry_sum.clone(), |c| self.add(&carry_sum, &c));
239
240 let next_mod = xs[0].wires()[i + 1].modulus();
243 let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
244 digit_carry = Some(self.proj(&carry, next_mod, Some(tt), channel)?);
245
246 let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
247
248 if i < n - 2 {
249 if max_carry < next_mod {
250 carry_carry = Some(self.mod_change(
251 digit_carry.as_ref().unwrap(),
252 next_max_val + 1,
253 channel,
254 )?);
255 } else {
256 let tt = (0..=max_val).map(|i| i / q).collect_vec();
257 carry_carry =
258 Some(self.proj(&carry, next_max_val + 1, Some(tt), channel)?);
259 }
260 } else {
261 carry_carry = None;
263 }
264 } else {
265 digit_carry = None;
266 carry_carry = None;
267 }
268 res.push(digit);
269 }
270 Ok(Bundle(res))
271 }
272
273 fn mixed_radix_addition_msb_only(
278 &mut self,
279 xs: &[Bundle<Self::Item>],
280 channel: &mut Channel,
281 ) -> swanky_error::Result<Self::Item> {
282 assert!(!xs.is_empty(), "`xs` cannot be empty");
283 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
284
285 let nargs = xs.len();
286 let n = xs[0].wires().len();
287
288 let mut opt_carry = None;
289 let mut max_carry = 0;
290
291 for i in 0..n - 1 {
292 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
294 let q = xs[0].moduli()[i];
296 let max_val = nargs as u16 * (q - 1) + max_carry;
298 max_carry = max_val / q;
300
301 let modded_ds = ds
304 .iter()
305 .map(|d| self.mod_change(d, max_val + 1, channel))
306 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
307 let sum = self.add_many(&modded_ds);
309 let sum_with_carry = opt_carry
311 .as_ref()
312 .map_or(sum.clone(), |c| self.add(&sum, c));
313
314 let next_mod = if i < n - 2 {
319 nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
320 } else {
321 xs[0].moduli()[i + 1] };
323
324 let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
325 opt_carry = Some(self.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
326 }
327
328 let ds = xs.iter().map(|x| x.wires()[n - 1].clone()).collect_vec();
330 let digit_sum = self.add_many(&ds);
331 Ok(opt_carry
332 .as_ref()
333 .map_or(digit_sum.clone(), |d| self.add(&digit_sum, d)))
334 }
335
336 fn eq_bundles(
341 &mut self,
342 x: &Bundle<Self::Item>,
343 y: &Bundle<Self::Item>,
344 channel: &mut Channel,
345 ) -> swanky_error::Result<Self::Item> {
346 assert_eq!(x.moduli(), y.moduli());
347
348 let wlen = x.wires().len() as u16;
349 let zs = x
350 .wires()
351 .iter()
352 .zip_eq(y.wires().iter())
353 .map(|(x, y)| {
354 let z = self.sub(x, y);
356 let mut eq_zero_tab = vec![0; x.modulus() as usize];
357 eq_zero_tab[0] = 1;
358 self.proj(&z, wlen + 1, Some(eq_zero_tab), channel)
359 })
360 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
361 let z = self.add_many(&zs);
363 let b = zs.len();
364 let mut tab = vec![0; b + 1];
365 tab[b] = 1;
366 self.proj(&z, 2, Some(tab), channel)
367 }
368}
369
370pub trait BinaryBundleGadgets: FancyBinary {
373 fn multiplex(
375 &mut self,
376 b: &Self::Item,
377 x: &Bundle<Self::Item>,
378 y: &Bundle<Self::Item>,
379 channel: &mut Channel,
380 ) -> swanky_error::Result<Bundle<Self::Item>> {
381 x.wires()
382 .iter()
383 .zip(y.wires().iter())
384 .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
385 .collect::<swanky_error::Result<_>>()
386 .map(Bundle)
387 }
388}
389
390pub trait BundleGadgets: Fancy {
393 fn encode_bundle(
395 &mut self,
396 values: &[u16],
397 moduli: &[u16],
398 channel: &mut Channel,
399 ) -> swanky_error::Result<Bundle<Self::Item>> {
400 self.encode_many(values, moduli, channel).map(Bundle::new)
401 }
402
403 fn receive_bundle(
405 &mut self,
406 moduli: &[u16],
407 channel: &mut Channel,
408 ) -> swanky_error::Result<Bundle<Self::Item>> {
409 self.receive_many(moduli, channel).map(Bundle::new)
410 }
411
412 fn encode_bundles(
417 &mut self,
418 values: &[Vec<u16>],
419 moduli: &[Vec<u16>],
420 channel: &mut Channel,
421 ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
422 let qs = moduli.iter().flatten().cloned().collect_vec();
423 let xs = values.iter().flatten().cloned().collect_vec();
424 assert_eq!(xs.len(), qs.len(), "unequal number of values and moduli");
425 let mut wires = self.encode_many(&xs, &qs, channel)?;
426 let buns = moduli
427 .iter()
428 .map(|qs| {
429 let ws = wires.drain(0..qs.len()).collect_vec();
430 Bundle::new(ws)
431 })
432 .collect_vec();
433 Ok(buns)
434 }
435
436 fn receive_many_bundles(
438 &mut self,
439 moduli: &[Vec<u16>],
440 channel: &mut Channel,
441 ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
442 let qs = moduli.iter().flatten().cloned().collect_vec();
443 let mut wires = self.receive_many(&qs, channel)?;
444 let buns = moduli
445 .iter()
446 .map(|qs| {
447 let ws = wires.drain(0..qs.len()).collect_vec();
448 Bundle::new(ws)
449 })
450 .collect_vec();
451 Ok(buns)
452 }
453
454 fn constant_bundle(
456 &mut self,
457 xs: &[u16],
458 ps: &[u16],
459 channel: &mut Channel,
460 ) -> swanky_error::Result<Bundle<Self::Item>> {
461 xs.iter()
462 .zip(ps.iter())
463 .map(|(&x, &p)| self.constant(x, p, channel))
464 .collect::<swanky_error::Result<_>>()
465 .map(Bundle)
466 }
467
468 fn output_bundle(
470 &mut self,
471 x: &Bundle<Self::Item>,
472 channel: &mut Channel,
473 ) -> swanky_error::Result<Option<Vec<u16>>> {
474 let ws = x.wires();
475 let mut outputs = Vec::with_capacity(ws.len());
476 for w in ws.iter() {
477 outputs.push(self.output(w, channel)?);
478 }
479 Ok(outputs.into_iter().collect())
480 }
481
482 fn output_bundles(
484 &mut self,
485 xs: &[Bundle<Self::Item>],
486 channel: &mut Channel,
487 ) -> swanky_error::Result<Option<Vec<Vec<u16>>>> {
488 let mut zs = Vec::with_capacity(xs.len());
489 for x in xs.iter() {
490 let z = self.output_bundle(x, channel)?;
491 zs.push(z);
492 }
493 Ok(zs.into_iter().collect())
494 }
495
496 fn shift(
502 &mut self,
503 x: &Bundle<Self::Item>,
504 n: usize,
505 channel: &mut Channel,
506 ) -> swanky_error::Result<Bundle<Self::Item>> {
507 let mut ws = x.wires().to_vec();
508 let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
509 for _ in 0..n {
510 ws.pop();
511 ws.insert(0, zero.clone());
512 }
513 Ok(Bundle(ws))
514 }
515
516 fn shift_extend(
519 &mut self,
520 x: &Bundle<Self::Item>,
521 n: usize,
522 channel: &mut Channel,
523 ) -> swanky_error::Result<Bundle<Self::Item>> {
524 let mut ws = x.wires().to_vec();
525 let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
526 for _ in 0..n {
527 ws.insert(0, zero.clone());
528 }
529 Ok(Bundle(ws))
530 }
531}