1use crate::{
2 FancyBinary,
3 fancy::{
4 HasModulus,
5 bundle::{Bundle, BundleGadgets},
6 },
7 util,
8};
9use itertools::Itertools;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12use std::ops::{Deref, DerefMut};
13use swanky_channel::Channel;
14
/// A [`Bundle`] whose wires are all binary (modulus 2), interpreted as the bits
/// of an integer by the [`BinaryGadgets`] methods.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BinaryBundle<W>(Bundle<W>);
19
20impl<W: Clone + HasModulus> BinaryBundle<W> {
21 pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
23 BinaryBundle(Bundle::new(ws))
24 }
25
26 pub fn extract(self) -> Bundle<W> {
28 self.0
29 }
30}
31
32impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
33 type Target = Bundle<W>;
34
35 fn deref(&self) -> &Bundle<W> {
36 &self.0
37 }
38}
39
40impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
41 fn deref_mut(&mut self) -> &mut Bundle<W> {
42 &mut self.0
43 }
44}
45
46impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
47 fn from(b: Bundle<W>) -> BinaryBundle<W> {
48 debug_assert!(b.moduli().iter().all(|&p| p == 2));
49 BinaryBundle(b)
50 }
51}
52
53impl<F: FancyBinary> BinaryGadgets for F {}
54
55pub trait BinaryGadgets: FancyBinary + BundleGadgets {
57 fn bin_encode(
59 &mut self,
60 value: u128,
61 nbits: usize,
62 channel: &mut Channel,
63 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
64 let xs = util::u128_to_bits(value, nbits);
65 self.encode_bundle(&xs, &vec![2; nbits], channel)
66 .map(BinaryBundle::from)
67 }
68
69 fn bin_receive(
71 &mut self,
72 nbits: usize,
73 channel: &mut Channel,
74 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
75 self.receive_bundle(&vec![2; nbits], channel)
76 .map(BinaryBundle::from)
77 }
78
79 fn bin_encode_many(
81 &mut self,
82 values: &[u128],
83 nbits: usize,
84 channel: &mut Channel,
85 ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
86 let xs = values
87 .iter()
88 .flat_map(|x| util::u128_to_bits(*x, nbits))
89 .collect_vec();
90 let mut wires = self.encode_many(&xs, &vec![2; values.len() * nbits], channel)?;
91 let buns = (0..values.len())
92 .map(|_| {
93 let ws = wires.drain(0..nbits).collect_vec();
94 BinaryBundle::new(ws)
95 })
96 .collect_vec();
97 Ok(buns)
98 }
99
100 fn bin_receive_many(
102 &mut self,
103 ninputs: usize,
104 nbits: usize,
105 channel: &mut Channel,
106 ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
107 let mut wires = self.receive_many(&vec![2; ninputs * nbits], channel)?;
108 let buns = (0..ninputs)
109 .map(|_| {
110 let ws = wires.drain(0..nbits).collect_vec();
111 BinaryBundle::new(ws)
112 })
113 .collect_vec();
114 Ok(buns)
115 }
116
117 fn bin_constant_bundle(
119 &mut self,
120 val: u128,
121 nbits: usize,
122 channel: &mut Channel,
123 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
124 self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits], channel)
125 .map(BinaryBundle)
126 }
127
128 fn bin_output(
130 &mut self,
131 x: &BinaryBundle<Self::Item>,
132 channel: &mut Channel,
133 ) -> swanky_error::Result<Option<u128>> {
134 Ok(self
135 .output_bundle(x, channel)?
136 .map(|bs| util::u128_from_bits(&bs)))
137 }
138
139 fn bin_outputs(
141 &mut self,
142 xs: &[BinaryBundle<Self::Item>],
143 channel: &mut Channel,
144 ) -> swanky_error::Result<Option<Vec<u128>>> {
145 let mut zs = Vec::with_capacity(xs.len());
146 for x in xs.iter() {
147 let z = self.bin_output(x, channel)?;
148 zs.push(z);
149 }
150 Ok(zs.into_iter().collect())
151 }
152
153 fn bin_xor(
155 &mut self,
156 x: &BinaryBundle<Self::Item>,
157 y: &BinaryBundle<Self::Item>,
158 ) -> BinaryBundle<Self::Item> {
159 BinaryBundle::new(
160 x.wires()
161 .iter()
162 .zip(y.wires().iter())
163 .map(|(x, y)| self.xor(x, y))
164 .collect::<Vec<Self::Item>>(),
165 )
166 }
167
168 fn bin_and(
170 &mut self,
171 x: &BinaryBundle<Self::Item>,
172 y: &BinaryBundle<Self::Item>,
173 channel: &mut Channel,
174 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
175 x.wires()
176 .iter()
177 .zip(y.wires().iter())
178 .map(|(x, y)| self.and(x, y, channel))
179 .collect::<swanky_error::Result<Vec<Self::Item>>>()
180 .map(BinaryBundle::new)
181 }
182
183 fn bin_or(
185 &mut self,
186 x: &BinaryBundle<Self::Item>,
187 y: &BinaryBundle<Self::Item>,
188 channel: &mut Channel,
189 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
190 x.wires()
191 .iter()
192 .zip(y.wires().iter())
193 .map(|(x, y)| self.or(x, y, channel))
194 .collect::<swanky_error::Result<Vec<Self::Item>>>()
195 .map(BinaryBundle::new)
196 }
197
198 fn bin_addition(
203 &mut self,
204 xs: &BinaryBundle<Self::Item>,
205 ys: &BinaryBundle<Self::Item>,
206 channel: &mut Channel,
207 ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
208 assert_eq!(xs.moduli(), ys.moduli());
209 let xwires = xs.wires();
210 let ywires = ys.wires();
211 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
212 let mut bs = vec![z];
213 for i in 1..xwires.len() {
214 let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
215 z = res.0;
216 c = res.1;
217 bs.push(z);
218 }
219 Ok((BinaryBundle::new(bs), c))
220 }
221
222 fn bin_addition_no_carry(
227 &mut self,
228 xs: &BinaryBundle<Self::Item>,
229 ys: &BinaryBundle<Self::Item>,
230 channel: &mut Channel,
231 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
232 assert_eq!(xs.moduli(), ys.moduli());
233 let xwires = xs.wires();
234 let ywires = ys.wires();
235 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None, channel)?;
236 let mut bs = vec![z];
237 for i in 1..xwires.len() - 1 {
238 let res = self.adder(&xwires[i], &ywires[i], Some(&c), channel)?;
239 z = res.0;
240 c = res.1;
241 bs.push(z);
242 }
243 z = self.xor_many(&[
245 xwires.last().unwrap().clone(),
246 ywires.last().unwrap().clone(),
247 c,
248 ]);
249 bs.push(z);
250 Ok(BinaryBundle::new(bs))
251 }
252
253 fn bin_multiplication_lower_half(
261 &mut self,
262 xs: &BinaryBundle<Self::Item>,
263 ys: &BinaryBundle<Self::Item>,
264 channel: &mut Channel,
265 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
266 assert_eq!(xs.moduli(), ys.moduli());
267
268 let xwires = xs.wires();
269 let ywires = ys.wires();
270
271 let mut sum = xwires
272 .iter()
273 .map(|x| self.and(x, &ywires[0], channel))
274 .collect::<swanky_error::Result<Vec<Self::Item>>>()
275 .map(BinaryBundle::new)?;
276
277 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
278 let mul = xwires
279 .iter()
280 .map(|x| self.and(x, ywire, channel))
281 .collect::<swanky_error::Result<Vec<Self::Item>>>()
282 .map(BinaryBundle::new)?;
283 let shifted = self.shift(&mul, i, channel).map(BinaryBundle)?;
284 sum = self.bin_addition_no_carry(&sum, &shifted, channel)?;
285 }
286
287 Ok(sum)
288 }
289
290 fn bin_mul(
295 &mut self,
296 xs: &BinaryBundle<Self::Item>,
297 ys: &BinaryBundle<Self::Item>,
298 channel: &mut Channel,
299 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
300 assert_eq!(xs.moduli(), ys.moduli());
301
302 let xwires = xs.wires();
303 let ywires = ys.wires();
304
305 let mut sum = xwires
306 .iter()
307 .map(|x| self.and(x, &ywires[0], channel))
308 .collect::<Result<_, _>>()
309 .map(BinaryBundle::new)?;
310
311 let zero = self.constant(0, 2, channel)?;
312 sum.pad(&zero, 1);
313
314 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
315 let mul = xwires
316 .iter()
317 .map(|x| self.and(x, ywire, channel))
318 .collect::<Result<_, _>>()
319 .map(BinaryBundle::new)?;
320 let shifted = self
321 .shift_extend(&mul, i, channel)
322 .map(BinaryBundle::from)?;
323 let res = self.bin_addition(&sum, &shifted, channel)?;
324 sum = res.0;
325 sum.push(res.1);
326 }
327
328 Ok(sum)
329 }
330
331 fn bin_div(
336 &mut self,
337 xs: &BinaryBundle<Self::Item>,
338 ys: &BinaryBundle<Self::Item>,
339 channel: &mut Channel,
340 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
341 assert_eq!(xs.moduli(), ys.moduli());
342 let ys_neg = self.bin_twos_complement(ys, channel)?;
343 let mut acc = self.bin_constant_bundle(0, xs.size(), channel)?;
344 let mut qs = BinaryBundle::new(Vec::new());
345 for x in xs.iter().rev() {
346 acc.pop();
347 acc.insert(0, x.clone());
348 let (res, cout) = self.bin_addition(&acc, &ys_neg, channel)?;
349 acc = self.bin_multiplex(&cout, &acc, &res, channel)?;
350 qs.push(cout);
351 }
352 qs.reverse(); Ok(qs)
354 }
355
356 fn bin_twos_complement(
358 &mut self,
359 xs: &BinaryBundle<Self::Item>,
360 channel: &mut Channel,
361 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
362 let not_xs = BinaryBundle::new(
363 xs.wires()
364 .iter()
365 .map(|x| self.negate(x))
366 .collect::<Vec<_>>(),
367 );
368 let one = self.bin_constant_bundle(1, xs.size(), channel)?;
369 self.bin_addition_no_carry(¬_xs, &one, channel)
370 }
371
372 fn bin_subtraction(
376 &mut self,
377 xs: &BinaryBundle<Self::Item>,
378 ys: &BinaryBundle<Self::Item>,
379 channel: &mut Channel,
380 ) -> swanky_error::Result<(BinaryBundle<Self::Item>, Self::Item)> {
381 let neg_ys = self.bin_twos_complement(ys, channel)?;
382 self.bin_addition(xs, &neg_ys, channel)
383 }
384
385 fn bin_multiplex_constant_bits(
387 &mut self,
388 x: &Self::Item,
389 c1: u128,
390 c2: u128,
391 nbits: usize,
392 channel: &mut Channel,
393 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
394 let c1_bs = util::u128_to_bits(c1, nbits)
395 .into_iter()
396 .map(|x: u16| x > 0)
397 .collect_vec();
398 let c2_bs = util::u128_to_bits(c2, nbits)
399 .into_iter()
400 .map(|x: u16| x > 0)
401 .collect_vec();
402 c1_bs
403 .into_iter()
404 .zip(c2_bs)
405 .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2, channel))
406 .collect::<swanky_error::Result<Vec<Self::Item>>>()
407 .map(BinaryBundle::new)
408 }
409
410 fn bin_multiplex(
412 &mut self,
413 b: &Self::Item,
414 x: &BinaryBundle<Self::Item>,
415 y: &BinaryBundle<Self::Item>,
416 channel: &mut Channel,
417 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
418 x.wires()
419 .iter()
420 .zip(y.wires().iter())
421 .map(|(xwire, ywire)| self.mux(b, xwire, ywire, channel))
422 .collect::<swanky_error::Result<Vec<Self::Item>>>()
423 .map(BinaryBundle::new)
424 }
425
426 fn bin_cmul(
428 &mut self,
429 x: &BinaryBundle<Self::Item>,
430 c: u128,
431 nbits: usize,
432 channel: &mut Channel,
433 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
434 let zero = self.bin_constant_bundle(0, nbits, channel)?;
435 util::u128_to_bits(c, nbits)
436 .into_iter()
437 .enumerate()
438 .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
439 .try_fold(zero, |z, shift_amt| {
440 let s = self.shift(x, shift_amt, channel).map(BinaryBundle)?;
441 self.bin_addition_no_carry(&z, &s, channel)
442 })
443 }
444
445 fn bin_abs(
447 &mut self,
448 x: &BinaryBundle<Self::Item>,
449 channel: &mut Channel,
450 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
451 let sign = x.wires().last().unwrap();
452 let negated = self.bin_twos_complement(x, channel)?;
453 self.bin_multiplex(sign, x, &negated, channel)
454 }
455
456 fn bin_lt_signed(
458 &mut self,
459 x: &BinaryBundle<Self::Item>,
460 y: &BinaryBundle<Self::Item>,
461 channel: &mut Channel,
462 ) -> swanky_error::Result<Self::Item> {
463 let x_neg = &x.wires().last().unwrap();
465 let y_neg = &y.wires().last().unwrap();
466 let x_pos = self.negate(x_neg);
467 let y_pos = self.negate(y_neg);
468
469 let x_lt_y_unsigned = self.bin_lt(x, y, channel)?;
472
473 let tru = self.constant(1, 2, channel)?;
475 let x_neg_y_pos = self.and(x_neg, &y_pos, channel)?;
476 let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru, channel)?;
477
478 let fls = self.constant(0, 2, channel)?;
480 let x_pos_y_neg = self.and(&x_pos, y_neg, channel)?;
481 self.mux(&x_pos_y_neg, &r2, &fls, channel)
482 }
483
484 fn bin_lt(
486 &mut self,
487 x: &BinaryBundle<Self::Item>,
488 y: &BinaryBundle<Self::Item>,
489 channel: &mut Channel,
490 ) -> swanky_error::Result<Self::Item> {
491 let (_, lhs) = self.bin_subtraction(x, y, channel)?;
494
495 let y_contains_1 = self.or_many(y.wires(), channel)?;
499 let y_eq_0 = self.negate(&y_contains_1);
500
501 let x_contains_1 = self.or_many(x.wires(), channel)?;
503
504 let rhs = self.and(&y_eq_0, &x_contains_1, channel)?;
506
507 let geq = self.or(&lhs, &rhs, channel)?;
512 let ngeq = self.negate(&geq);
513
514 let xy_neq_0 = self.or(&y_contains_1, &x_contains_1, channel)?;
515 self.and(&xy_neq_0, &ngeq, channel)
516 }
517
518 fn bin_geq(
520 &mut self,
521 x: &BinaryBundle<Self::Item>,
522 y: &BinaryBundle<Self::Item>,
523 channel: &mut Channel,
524 ) -> swanky_error::Result<Self::Item> {
525 let z = self.bin_lt(x, y, channel)?;
526 Ok(self.negate(&z))
527 }
528
529 fn bin_max(
534 &mut self,
535 xs: &[BinaryBundle<Self::Item>],
536 channel: &mut Channel,
537 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
538 assert!(!xs.is_empty(), "`xs` cannot be empty");
539 xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
540 let pos = self.bin_lt(&x, y, channel)?;
541 let neg = self.negate(&pos);
542 Ok(BinaryBundle::new(
543 x.wires()
544 .iter()
545 .zip(y.wires().iter())
546 .map(|(x, y)| {
547 let xp = self.and(x, &neg, channel)?;
548 let yp = self.and(y, &pos, channel)?;
549 Ok(self.xor(&xp, &yp))
550 })
551 .collect::<swanky_error::Result<Vec<Self::Item>>>()?,
552 ))
553 })
554 }
555
556 fn bin_demux(
561 &mut self,
562 x: &BinaryBundle<Self::Item>,
563 channel: &mut Channel,
564 ) -> swanky_error::Result<Vec<Self::Item>> {
565 let wires = x.wires();
566 let nbits = wires.len();
567 assert!(nbits <= 8, "wire bitlength is too large");
568
569 let mut outs = Vec::with_capacity(1 << nbits);
570
571 for ix in 0..1 << nbits {
572 let mut acc = wires[0].clone();
573 if (ix & 1) == 0 {
574 acc = self.negate(&acc);
575 }
576 for (i, w) in wires.iter().enumerate().skip(1) {
577 if ((ix >> i) & 1) > 0 {
578 acc = self.and(&acc, w, channel)?;
579 } else {
580 let not_w = self.negate(w);
581 acc = self.and(&acc, ¬_w, channel)?;
582 }
583 }
584 outs.push(acc);
585 }
586
587 Ok(outs)
588 }
589
590 fn bin_rsa(&mut self, x: &BinaryBundle<Self::Item>, c: usize) -> BinaryBundle<Self::Item> {
592 self.bin_shr(x, c, x.wires().last().unwrap())
593 }
594
595 fn bin_rsl(
597 &mut self,
598 x: &BinaryBundle<Self::Item>,
599 c: usize,
600 channel: &mut Channel,
601 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
602 let zero = self.constant(0, 2, channel)?;
603 Ok(self.bin_shr(x, c, &zero))
604 }
605
606 fn bin_shr(
608 &mut self,
609 x: &BinaryBundle<Self::Item>,
610 c: usize,
611 pad: &Self::Item,
612 ) -> BinaryBundle<Self::Item> {
613 let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
614
615 for i in 0..x.wires().len() {
616 let src_idx = i + c;
617 if src_idx >= x.wires().len() {
618 wires.push(pad.clone())
619 } else {
620 wires.push(x.wires()[src_idx].clone())
621 }
622 }
623
624 BinaryBundle::new(wires)
625 }
626 fn bin_eq_bundles(
628 &mut self,
629 x: &BinaryBundle<Self::Item>,
630 y: &BinaryBundle<Self::Item>,
631 channel: &mut Channel,
632 ) -> swanky_error::Result<Self::Item> {
633 let zs = x
635 .wires()
636 .iter()
637 .zip_eq(y.wires().iter())
638 .map(|(x, y)| {
639 let xy = self.xor(x, y);
640 self.negate(&xy)
641 })
642 .collect::<Vec<_>>();
643 self.and_many(&zs, channel)
646 }
647}