1use crate::{
2 FancyBinary,
3 fancy::{
4 HasModulus,
5 bundle::{Bundle, BundleGadgets},
6 },
7 util,
8};
9use itertools::Itertools;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12use std::ops::{Deref, DerefMut};
13use swanky_channel::Channel;
14
/// A [`Bundle`] whose wires are all binary (modulus 2).
///
/// This is a thin newtype over [`Bundle`]. It dereferences to the inner
/// bundle, and converting from a plain [`Bundle`] debug-asserts that every
/// modulus is 2. The binary gadgets in [`BinaryGadgets`] operate on it.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BinaryBundle<W>(Bundle<W>);
19
20impl<W: Clone + HasModulus> BinaryBundle<W> {
21 pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
23 BinaryBundle(Bundle::new(ws))
24 }
25
26 pub fn extract(self) -> Bundle<W> {
28 self.0
29 }
30}
31
32impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
33 type Target = Bundle<W>;
34
35 fn deref(&self) -> &Bundle<W> {
36 &self.0
37 }
38}
39
40impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
41 fn deref_mut(&mut self) -> &mut Bundle<W> {
42 &mut self.0
43 }
44}
45
46impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
47 fn from(b: Bundle<W>) -> BinaryBundle<W> {
48 debug_assert!(b.moduli().iter().all(|&p| p == 2));
49 BinaryBundle(b)
50 }
51}
52
53impl<F: FancyBinary> BinaryGadgets for F {}
54
55pub trait BinaryGadgets: FancyBinary + BundleGadgets {
57 fn bin_constant_bundle(
59 &mut self,
60 val: u128,
61 nbits: usize,
62 channel: &mut Channel,
63 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
64 self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits], channel)
65 .map(BinaryBundle)
66 }
67
68 fn bin_output(
70 &mut self,
71 x: &BinaryBundle<Self::Item>,
72 channel: &mut Channel,
73 ) -> swanky_error::Result<Option<u128>> {
74 Ok(self
75 .output_bundle(x, channel)?
76 .map(|bs| util::u128_from_bits(&bs)))
77 }
78
79 fn bin_outputs(
81 &mut self,
82 xs: &[BinaryBundle<Self::Item>],
83 channel: &mut Channel,
84 ) -> swanky_error::Result<Option<Vec<u128>>> {
85 let mut zs = Vec::with_capacity(xs.len());
86 for x in xs.iter() {
87 let z = self.bin_output(x, channel)?;
88 zs.push(z);
89 }
90 Ok(zs.into_iter().collect())
91 }
92
93 fn bin_xor(
95 &mut self,
96 x: &BinaryBundle<Self::Item>,
97 y: &BinaryBundle<Self::Item>,
98 ) -> BinaryBundle<Self::Item> {
99 BinaryBundle::new(
100 x.wires()
101 .iter()
102 .zip(y.wires().iter())
103 .map(|(x, y)| self.xor(x, y))
104 .collect::<Vec<Self::Item>>(),
105 )
106 }
107
108 fn bin_and(
110 &mut self,
111 x: &BinaryBundle<Self::Item>,
112 y: &BinaryBundle<Self::Item>,
113 channel: &mut Channel,
114 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
115 x.wires()
116 .iter()
117 .zip(y.wires().iter())
118 .map(|(x, y)| self.and(x, y, channel))
119 .collect::<swanky_error::Result<Vec<Self::Item>>>()
120 .map(BinaryBundle::new)
121 }
122
123 fn bin_or(
125 &mut self,
126 x: &BinaryBundle<Self::Item>,
127 y: &BinaryBundle<Self::Item>,
128 channel: &mut Channel,
129 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
130 x.wires()
131 .iter()
132 .zip(y.wires().iter())
133 .map(|(x, y)| self.or(x, y, channel))
134 .collect::<swanky_error::Result<Vec<Self::Item>>>()
135 .map(BinaryBundle::new)
136 }
137
138 fn bin_addition(
143 &mut self,
144 xs: &BinaryBundle<Self::Item>,
145 ys: &BinaryBundle<Self::Item>,
146 channel: &mut Channel,
147 ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
148 assert_eq!(xs.moduli(), ys.moduli());
149 let xwires = xs.wires();
150 let ywires = ys.wires();
151 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
152 let mut bs = vec![z];
153 for i in 1..xwires.len() {
154 let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
155 z = res.0;
156 c = res.1;
157 bs.push(z);
158 }
159 Ok((BinaryBundle::new(bs), c))
160 }
161
162 fn bin_addition_no_carry(
167 &mut self,
168 xs: &BinaryBundle<Self::Item>,
169 ys: &BinaryBundle<Self::Item>,
170 channel: &mut Channel,
171 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
172 assert_eq!(xs.moduli(), ys.moduli());
173 let xwires = xs.wires();
174 let ywires = ys.wires();
175 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
176 let mut bs = vec![z];
177 for i in 1..xwires.len() - 1 {
178 let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
179 z = res.0;
180 c = res.1;
181 bs.push(z);
182 }
183 z = self.xor_many(&[
185 xwires.last().unwrap().clone(),
186 ywires.last().unwrap().clone(),
187 c,
188 ]);
189 bs.push(z);
190 Ok(BinaryBundle::new(bs))
191 }
192
193 fn bin_multiplication_lower_half(
201 &mut self,
202 xs: &BinaryBundle<Self::Item>,
203 ys: &BinaryBundle<Self::Item>,
204 channel: &mut Channel,
205 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
206 assert_eq!(xs.moduli(), ys.moduli());
207
208 let xwires = xs.wires();
209 let ywires = ys.wires();
210
211 let mut sum = xwires
212 .iter()
213 .map(|x| self.and(x, &ywires[0], channel))
214 .collect::<swanky_error::Result<Vec<Self::Item>>>()
215 .map(BinaryBundle::new)?;
216
217 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
218 let mul = xwires
219 .iter()
220 .map(|x| self.and(x, ywire, channel))
221 .collect::<swanky_error::Result<Vec<Self::Item>>>()
222 .map(BinaryBundle::new)?;
223 let shifted = self.shift(&mul, i, channel).map(BinaryBundle)?;
224 sum = self.bin_addition_no_carry(&sum, &shifted, channel)?;
225 }
226
227 Ok(sum)
228 }
229
230 fn bin_mul(
235 &mut self,
236 xs: &BinaryBundle<Self::Item>,
237 ys: &BinaryBundle<Self::Item>,
238 channel: &mut Channel,
239 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
240 assert_eq!(xs.moduli(), ys.moduli());
241
242 let xwires = xs.wires();
243 let ywires = ys.wires();
244
245 let mut sum = xwires
246 .iter()
247 .map(|x| self.and(x, &ywires[0], channel))
248 .collect::<Result<_, _>>()
249 .map(BinaryBundle::new)?;
250
251 let zero = self.constant(0, 2, channel)?;
252 sum.pad(&zero, 1);
253
254 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
255 let mul = xwires
256 .iter()
257 .map(|x| self.and(x, ywire, channel))
258 .collect::<Result<_, _>>()
259 .map(BinaryBundle::new)?;
260 let shifted = self
261 .shift_extend(&mul, i, channel)
262 .map(BinaryBundle::from)?;
263 let res = self.bin_addition(&sum, &shifted, channel)?;
264 sum = res.0;
265 sum.push(res.1);
266 }
267
268 Ok(sum)
269 }
270
271 fn bin_div(
276 &mut self,
277 xs: &BinaryBundle<Self::Item>,
278 ys: &BinaryBundle<Self::Item>,
279 channel: &mut Channel,
280 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
281 assert_eq!(xs.moduli(), ys.moduli());
282 let ys_neg = self.bin_twos_complement(ys, channel)?;
283 let mut acc = self.bin_constant_bundle(0, xs.size(), channel)?;
284 let mut qs = BinaryBundle::new(Vec::new());
285 for x in xs.iter().rev() {
286 acc.pop();
287 acc.insert(0, x.clone());
288 let (res, cout) = self.bin_addition(&acc, &ys_neg, channel)?;
289 acc = self.bin_multiplex(&cout, &acc, &res, channel)?;
290 qs.push(cout);
291 }
292 qs.reverse(); Ok(qs)
294 }
295
296 fn bin_twos_complement(
298 &mut self,
299 xs: &BinaryBundle<Self::Item>,
300 channel: &mut Channel,
301 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
302 let not_xs = BinaryBundle::new(
303 xs.wires()
304 .iter()
305 .map(|x| self.negate(x))
306 .collect::<Vec<_>>(),
307 );
308 let one = self.bin_constant_bundle(1, xs.size(), channel)?;
309 self.bin_addition_no_carry(¬_xs, &one, channel)
310 }
311
312 fn bin_subtraction(
316 &mut self,
317 xs: &BinaryBundle<Self::Item>,
318 ys: &BinaryBundle<Self::Item>,
319 channel: &mut Channel,
320 ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
321 let neg_ys = self.bin_twos_complement(ys, channel)?;
322 self.bin_addition(xs, &neg_ys, channel)
323 }
324
325 fn bin_multiplex_constant_bits(
327 &mut self,
328 x: &Self::Item,
329 c1: u128,
330 c2: u128,
331 nbits: usize,
332 channel: &mut Channel,
333 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
334 let c1_bs = util::u128_to_bits(c1, nbits)
335 .into_iter()
336 .map(|x: u16| x > 0)
337 .collect_vec();
338 let c2_bs = util::u128_to_bits(c2, nbits)
339 .into_iter()
340 .map(|x: u16| x > 0)
341 .collect_vec();
342 c1_bs
343 .into_iter()
344 .zip(c2_bs)
345 .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2, channel))
346 .collect::<swanky_error::Result<Vec<Self::Item>>>()
347 .map(BinaryBundle::new)
348 }
349
350 fn bin_multiplex(
352 &mut self,
353 b: &Self::Item,
354 x: &BinaryBundle<Self::Item>,
355 y: &BinaryBundle<Self::Item>,
356 channel: &mut Channel,
357 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
358 x.wires()
359 .iter()
360 .zip(y.wires().iter())
361 .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
362 .collect::<swanky_error::Result<Vec<Self::Item>>>()
363 .map(BinaryBundle::new)
364 }
365
366 fn bin_cmul(
368 &mut self,
369 x: &BinaryBundle<Self::Item>,
370 c: u128,
371 nbits: usize,
372 channel: &mut Channel,
373 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
374 let zero = self.bin_constant_bundle(0, nbits, channel)?;
375 util::u128_to_bits(c, nbits)
376 .into_iter()
377 .enumerate()
378 .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
379 .try_fold(zero, |z, shift_amt| {
380 let s = self.shift(x, shift_amt, channel).map(BinaryBundle)?;
381 self.bin_addition_no_carry(&z, &s, channel)
382 })
383 }
384
385 fn bin_abs(
387 &mut self,
388 x: &BinaryBundle<Self::Item>,
389 channel: &mut Channel,
390 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
391 let sign = x.wires().last().unwrap();
392 let negated = self.bin_twos_complement(x, channel)?;
393 self.bin_multiplex(sign, x, &negated, channel)
394 }
395
396 fn bin_lt_signed(
398 &mut self,
399 x: &BinaryBundle<Self::Item>,
400 y: &BinaryBundle<Self::Item>,
401 channel: &mut Channel,
402 ) -> swanky_error::Result<Self::Item> {
403 let x_neg = &x.wires().last().unwrap();
405 let y_neg = &y.wires().last().unwrap();
406 let x_pos = self.negate(x_neg);
407 let y_pos = self.negate(y_neg);
408
409 let x_lt_y_unsigned = self.bin_lt(x, y, channel)?;
412
413 let tru = self.constant(1, 2, channel)?;
415 let x_neg_y_pos = self.and(x_neg, &y_pos, channel)?;
416 let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru, channel)?;
417
418 let fls = self.constant(0, 2, channel)?;
420 let x_pos_y_neg = self.and(&x_pos, y_neg, channel)?;
421 self.mux(&x_pos_y_neg, &r2, &fls, channel)
422 }
423
424 fn bin_lt(
426 &mut self,
427 x: &BinaryBundle<Self::Item>,
428 y: &BinaryBundle<Self::Item>,
429 channel: &mut Channel,
430 ) -> swanky_error::Result<Self::Item> {
431 let (_, lhs) = self.bin_subtraction(x, y, channel)?;
434
435 let y_contains_1 = self.or_many(y.wires(), channel)?;
439 let y_eq_0 = self.negate(&y_contains_1);
440
441 let x_contains_1 = self.or_many(x.wires(), channel)?;
443
444 let rhs = self.and(&y_eq_0, &x_contains_1, channel)?;
446
447 let geq = self.or(&lhs, &rhs, channel)?;
452 let ngeq = self.negate(&geq);
453
454 let xy_neq_0 = self.or(&y_contains_1, &x_contains_1, channel)?;
455 self.and(&xy_neq_0, &ngeq, channel)
456 }
457
458 fn bin_geq(
460 &mut self,
461 x: &BinaryBundle<Self::Item>,
462 y: &BinaryBundle<Self::Item>,
463 channel: &mut Channel,
464 ) -> swanky_error::Result<Self::Item> {
465 let z = self.bin_lt(x, y, channel)?;
466 Ok(self.negate(&z))
467 }
468
469 fn bin_max(
474 &mut self,
475 xs: &[BinaryBundle<Self::Item>],
476 channel: &mut Channel,
477 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
478 assert!(!xs.is_empty(), "`xs` cannot be empty");
479 xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
480 let pos = self.bin_lt(&x, y, channel)?;
481 let neg = self.negate(&pos);
482 Ok(BinaryBundle::new(
483 x.wires()
484 .iter()
485 .zip(y.wires().iter())
486 .map(|(x, y)| {
487 let xp = self.and(x, &neg, channel)?;
488 let yp = self.and(y, &pos, channel)?;
489 Ok(self.xor(&xp, &yp))
490 })
491 .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
492 ))
493 })
494 }
495
496 fn bin_demux(
501 &mut self,
502 x: &BinaryBundle<Self::Item>,
503 channel: &mut Channel,
504 ) -> swanky_error::Result<Vec<Self::Item>> {
505 let wires = x.wires();
506 let nbits = wires.len();
507 assert!(nbits <= 8, "wire bitlength is too large");
508
509 let mut outs = Vec::with_capacity(1 << nbits);
510
511 for ix in 0..1 << nbits {
512 let mut acc = wires[0].clone();
513 if (ix & 1) == 0 {
514 acc = self.negate(&acc);
515 }
516 for (i, w) in wires.iter().enumerate().skip(1) {
517 if ((ix >> i) & 1) > 0 {
518 acc = self.and(&acc, w, channel)?;
519 } else {
520 let not_w = self.negate(w);
521 acc = self.and(&acc, ¬_w, channel)?;
522 }
523 }
524 outs.push(acc);
525 }
526
527 Ok(outs)
528 }
529
530 fn bin_rsa(&mut self, x: &BinaryBundle<Self::Item>, c: usize) -> BinaryBundle<Self::Item> {
532 self.bin_shr(x, c, x.wires().last().unwrap())
533 }
534
535 fn bin_rsl(
537 &mut self,
538 x: &BinaryBundle<Self::Item>,
539 c: usize,
540 channel: &mut Channel,
541 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
542 let zero = self.constant(0, 2, channel)?;
543 Ok(self.bin_shr(x, c, &zero))
544 }
545
546 fn bin_shr(
548 &mut self,
549 x: &BinaryBundle<Self::Item>,
550 c: usize,
551 pad: &Self::Item,
552 ) -> BinaryBundle<Self::Item> {
553 let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
554
555 for i in 0..x.wires().len() {
556 let src_idx = i + c;
557 if src_idx >= x.wires().len() {
558 wires.push(pad.clone())
559 } else {
560 wires.push(x.wires()[src_idx].clone())
561 }
562 }
563
564 BinaryBundle::new(wires)
565 }
566 fn bin_eq_bundles(
568 &mut self,
569 x: &BinaryBundle<Self::Item>,
570 y: &BinaryBundle<Self::Item>,
571 channel: &mut Channel,
572 ) -> swanky_error::Result<Self::Item> {
573 let zs = x
575 .wires()
576 .iter()
577 .zip_eq(y.wires().iter())
578 .map(|(x, y)| {
579 let xy = self.xor(x, y);
580 self.negate(&xy)
581 })
582 .collect::<Vec<_>>();
583 self.and_many(&zs, channel)
586 }
587}