1use crate::{
2 FancyArithmetic, FancyBinary,
3 fancy::{Fancy, HasModulus},
4};
5use itertools::Itertools;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::ops::Index;
9use swanky_channel::Channel;
10
/// An ordered sequence of wires treated as a single value.
///
/// Each wire carries its own modulus (see `HasModulus`). The bundle only stores the wires.
/// How they are interpreted is decided by the gadgets that consume it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct Bundle<W>(Vec<W>);
15
16impl<W: Clone + HasModulus> Bundle<W> {
17 pub fn new(ws: Vec<W>) -> Bundle<W> {
19 Bundle(ws)
20 }
21
22 pub fn moduli(&self) -> Vec<u16> {
24 self.0.iter().map(HasModulus::modulus).collect()
25 }
26
27 pub fn wires(&self) -> &[W] {
29 &self.0
30 }
31
32 pub fn size(&self) -> usize {
34 self.0.len()
35 }
36
37 pub fn is_binary(&self) -> bool {
39 self.moduli().iter().all(|m| *m == 2)
40 }
41
42 pub fn with_moduli(&self, moduli: &[u16]) -> Bundle<W> {
44 let old_ws = self.wires();
45 let mut new_ws = Vec::with_capacity(moduli.len());
46 for &p in moduli {
47 if let Some(w) = old_ws.iter().find(|&x| x.modulus() == p) {
48 new_ws.push(w.clone());
49 } else {
50 panic!("Bundle::with_moduli: no {} modulus in bundle", p);
51 }
52 }
53 Bundle(new_ws)
54 }
55
56 pub fn pad(&mut self, val: &W, n: usize) {
58 for _ in 0..n {
59 self.0.push(val.clone());
60 }
61 }
62
63 pub fn extract(&mut self, wire_index: usize) -> W {
65 self.0.remove(wire_index)
66 }
67
68 pub fn insert(&mut self, wire_index: usize, val: W) {
70 self.0.insert(wire_index, val)
71 }
72
73 pub fn push(&mut self, val: W) {
75 self.0.push(val);
76 }
77
78 pub fn pop(&mut self) -> Option<W> {
80 self.0.pop()
81 }
82
83 pub fn iter(&self) -> std::slice::Iter<'_, W> {
85 self.0.iter()
86 }
87
88 pub fn reverse(&mut self) {
90 self.0.reverse();
91 }
92}
93
94impl<W: Clone + HasModulus> Index<usize> for Bundle<W> {
95 type Output = W;
96
97 fn index(&self, idx: usize) -> &Self::Output {
98 self.0.index(idx)
99 }
100}
101
102impl<F: Fancy> BundleGadgets for F {}
103impl<F: FancyArithmetic> ArithmeticBundleGadgets for F {}
104impl<F: FancyBinary> BinaryBundleGadgets for F {}
105
106pub trait ArithmeticBundleGadgets: FancyArithmetic {
109 fn add_bundles(
116 &mut self,
117 x: &Bundle<Self::Item>,
118 y: &Bundle<Self::Item>,
119 ) -> Bundle<Self::Item> {
120 assert_eq!(
121 x.wires().len(),
122 y.wires().len(),
123 "`x` and `y` must be the same length"
124 );
125 Bundle::new(
126 x.wires()
127 .iter()
128 .zip(y.wires().iter())
129 .map(|(x, y)| self.add(x, y))
130 .collect::<Vec<Self::Item>>(),
131 )
132 }
133
134 fn sub_bundles(
141 &mut self,
142 x: &Bundle<Self::Item>,
143 y: &Bundle<Self::Item>,
144 ) -> Bundle<Self::Item> {
145 assert_eq!(
146 x.wires().len(),
147 y.wires().len(),
148 "`x` and `y` must be the same length"
149 );
150 Bundle::new(
151 x.wires()
152 .iter()
153 .zip(y.wires().iter())
154 .map(|(x, y)| self.sub(x, y))
155 .collect::<Vec<Self::Item>>(),
156 )
157 }
158
159 fn mul_bundles(
163 &mut self,
164 x: &Bundle<Self::Item>,
165 y: &Bundle<Self::Item>,
166 channel: &mut Channel,
167 ) -> swanky_error::Result<Bundle<Self::Item>> {
168 x.wires()
169 .iter()
170 .zip(y.wires().iter())
171 .map(|(x, y)| self.mul(x, y, channel))
172 .collect::<swanky_error::Result<_>>()
173 .map(Bundle::new)
174 }
175
176 fn mixed_radix_addition(
181 &mut self,
182 xs: &[Bundle<Self::Item>],
183 channel: &mut Channel,
184 ) -> swanky_error::Result<Bundle<Self::Item>> {
185 assert!(!xs.is_empty(), "`xs` cannot be empty");
186 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
187
188 let nargs = xs.len();
189 let n = xs[0].wires().len();
190
191 let mut digit_carry = None;
192 let mut carry_carry = None;
193 let mut max_carry = 0;
194
195 let mut res = Vec::with_capacity(n);
196
197 for i in 0..n {
198 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
200
201 let digit_sum = self.add_many(&ds);
203 let digit = digit_carry.map_or(digit_sum.clone(), |d| self.add(&digit_sum, &d));
204
205 if i < n - 1 {
206 let q = xs[0].wires()[i].modulus();
208 let max_val = nargs as u16 * (q - 1) + max_carry;
210 max_carry = max_val / q;
212
213 let modded_ds = ds
214 .iter()
215 .map(|d| self.mod_change(d, max_val + 1, channel))
216 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
217
218 let carry_sum = self.add_many(&modded_ds);
219 let carry = carry_carry.map_or(carry_sum.clone(), |c| self.add(&carry_sum, &c));
221
222 let next_mod = xs[0].wires()[i + 1].modulus();
225 let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
226 digit_carry = Some(self.proj(&carry, next_mod, Some(tt), channel)?);
227
228 let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
229
230 if i < n - 2 {
231 if max_carry < next_mod {
232 carry_carry = Some(self.mod_change(
233 digit_carry.as_ref().unwrap(),
234 next_max_val + 1,
235 channel,
236 )?);
237 } else {
238 let tt = (0..=max_val).map(|i| i / q).collect_vec();
239 carry_carry =
240 Some(self.proj(&carry, next_max_val + 1, Some(tt), channel)?);
241 }
242 } else {
243 carry_carry = None;
245 }
246 } else {
247 digit_carry = None;
248 carry_carry = None;
249 }
250 res.push(digit);
251 }
252 Ok(Bundle(res))
253 }
254
255 fn mixed_radix_addition_msb_only(
260 &mut self,
261 xs: &[Bundle<Self::Item>],
262 channel: &mut Channel,
263 ) -> swanky_error::Result<Self::Item> {
264 assert!(!xs.is_empty(), "`xs` cannot be empty");
265 assert!(xs.iter().all(|x| x.moduli() == xs[0].moduli()));
266
267 let nargs = xs.len();
268 let n = xs[0].wires().len();
269
270 let mut opt_carry = None;
271 let mut max_carry = 0;
272
273 for i in 0..n - 1 {
274 let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
276 let q = xs[0].moduli()[i];
278 let max_val = nargs as u16 * (q - 1) + max_carry;
280 max_carry = max_val / q;
282
283 let modded_ds = ds
286 .iter()
287 .map(|d| self.mod_change(d, max_val + 1, channel))
288 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
289 let sum = self.add_many(&modded_ds);
291 let sum_with_carry = opt_carry
293 .as_ref()
294 .map_or(sum.clone(), |c| self.add(&sum, c));
295
296 let next_mod = if i < n - 2 {
301 nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
302 } else {
303 xs[0].moduli()[i + 1] };
305
306 let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
307 opt_carry = Some(self.proj(&sum_with_carry, next_mod, Some(tt), channel)?);
308 }
309
310 let ds = xs.iter().map(|x| x.wires()[n - 1].clone()).collect_vec();
312 let digit_sum = self.add_many(&ds);
313 Ok(opt_carry
314 .as_ref()
315 .map_or(digit_sum.clone(), |d| self.add(&digit_sum, d)))
316 }
317
318 fn mask(
320 &mut self,
321 b: &Self::Item,
322 x: &Bundle<Self::Item>,
323 channel: &mut Channel,
324 ) -> swanky_error::Result<Bundle<Self::Item>> {
325 x.wires()
326 .iter()
327 .map(|xwire| self.mul(xwire, b, channel))
328 .collect::<swanky_error::Result<_>>()
329 .map(Bundle)
330 }
331
332 fn eq_bundles(
337 &mut self,
338 x: &Bundle<Self::Item>,
339 y: &Bundle<Self::Item>,
340 channel: &mut Channel,
341 ) -> swanky_error::Result<Self::Item> {
342 assert_eq!(x.moduli(), y.moduli());
343
344 let wlen = x.wires().len() as u16;
345 let zs = x
346 .wires()
347 .iter()
348 .zip_eq(y.wires().iter())
349 .map(|(x, y)| {
350 let z = self.sub(x, y);
352 let mut eq_zero_tab = vec![0; x.modulus() as usize];
353 eq_zero_tab[0] = 1;
354 self.proj(&z, wlen + 1, Some(eq_zero_tab), channel)
355 })
356 .collect::<swanky_error::Result<Vec<Self::Item>>>()?;
357 let z = self.add_many(&zs);
359 let b = zs.len();
360 let mut tab = vec![0; b + 1];
361 tab[b] = 1;
362 self.proj(&z, 2, Some(tab), channel)
363 }
364}
365
366pub trait BinaryBundleGadgets: FancyBinary {
369 fn multiplex(
371 &mut self,
372 b: &Self::Item,
373 x: &Bundle<Self::Item>,
374 y: &Bundle<Self::Item>,
375 channel: &mut Channel,
376 ) -> swanky_error::Result<Bundle<Self::Item>> {
377 x.wires()
378 .iter()
379 .zip(y.wires().iter())
380 .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
381 .collect::<swanky_error::Result<_>>()
382 .map(Bundle)
383 }
384}
385
386pub trait BundleGadgets: Fancy {
389 fn constant_bundle(
391 &mut self,
392 xs: &[u16],
393 ps: &[u16],
394 channel: &mut Channel,
395 ) -> swanky_error::Result<Bundle<Self::Item>> {
396 xs.iter()
397 .zip(ps.iter())
398 .map(|(&x, &p)| self.constant(x, p, channel))
399 .collect::<swanky_error::Result<_>>()
400 .map(Bundle)
401 }
402
403 fn output_bundle(
405 &mut self,
406 x: &Bundle<Self::Item>,
407 channel: &mut Channel,
408 ) -> swanky_error::Result<Option<Vec<u16>>> {
409 let ws = x.wires();
410 let mut outputs = Vec::with_capacity(ws.len());
411 for w in ws.iter() {
412 outputs.push(self.output(w, channel)?);
413 }
414 Ok(outputs.into_iter().collect())
415 }
416
417 fn output_bundles(
419 &mut self,
420 xs: &[Bundle<Self::Item>],
421 channel: &mut Channel,
422 ) -> swanky_error::Result<Option<Vec<Vec<u16>>>> {
423 let mut zs = Vec::with_capacity(xs.len());
424 for x in xs.iter() {
425 let z = self.output_bundle(x, channel)?;
426 zs.push(z);
427 }
428 Ok(zs.into_iter().collect())
429 }
430
431 fn shift(
437 &mut self,
438 x: &Bundle<Self::Item>,
439 n: usize,
440 channel: &mut Channel,
441 ) -> swanky_error::Result<Bundle<Self::Item>> {
442 let mut ws = x.wires().to_vec();
443 let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
444 for _ in 0..n {
445 ws.pop();
446 ws.insert(0, zero.clone());
447 }
448 Ok(Bundle(ws))
449 }
450
451 fn shift_extend(
454 &mut self,
455 x: &Bundle<Self::Item>,
456 n: usize,
457 channel: &mut Channel,
458 ) -> swanky_error::Result<Bundle<Self::Item>> {
459 let mut ws = x.wires().to_vec();
460 let zero = self.constant(0, ws.last().unwrap().modulus(), channel)?;
461 for _ in 0..n {
462 ws.insert(0, zero.clone());
463 }
464 Ok(Bundle(ws))
465 }
466}