1use crate::{
2 FancyBinary,
3 errors::FancyError,
4 fancy::{
5 HasModulus,
6 bundle::{Bundle, BundleGadgets},
7 },
8 util,
9};
10use itertools::Itertools;
11use std::ops::{Deref, DerefMut};
12
13#[derive(Clone)]
15pub struct BinaryBundle<W>(Bundle<W>);
16
17impl<W: Clone + HasModulus> BinaryBundle<W> {
18 pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
20 BinaryBundle(Bundle::new(ws))
21 }
22
23 pub fn extract(self) -> Bundle<W> {
25 self.0
26 }
27}
28
29impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
30 type Target = Bundle<W>;
31
32 fn deref(&self) -> &Bundle<W> {
33 &self.0
34 }
35}
36
37impl<W: Clone + HasModulus> DerefMut for BinaryBundle<W> {
38 fn deref_mut(&mut self) -> &mut Bundle<W> {
39 &mut self.0
40 }
41}
42
43impl<W: Clone + HasModulus> From<Bundle<W>> for BinaryBundle<W> {
44 fn from(b: Bundle<W>) -> BinaryBundle<W> {
45 debug_assert!(b.moduli().iter().all(|&p| p == 2));
46 BinaryBundle(b)
47 }
48}
49
50impl<F: FancyBinary> BinaryGadgets for F {}
51
52pub trait BinaryGadgets: FancyBinary + BundleGadgets {
54 fn bin_constant_bundle(
56 &mut self,
57 val: u128,
58 nbits: usize,
59 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
60 self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits])
61 .map(BinaryBundle)
62 }
63
64 fn bin_output(&mut self, x: &BinaryBundle<Self::Item>) -> Result<Option<u128>, Self::Error> {
66 Ok(self.output_bundle(x)?.map(|bs| util::u128_from_bits(&bs)))
67 }
68
69 fn bin_outputs(
71 &mut self,
72 xs: &[BinaryBundle<Self::Item>],
73 ) -> Result<Option<Vec<u128>>, Self::Error> {
74 let mut zs = Vec::with_capacity(xs.len());
75 for x in xs.iter() {
76 let z = self.bin_output(x)?;
77 zs.push(z);
78 }
79 Ok(zs.into_iter().collect())
80 }
81
82 fn bin_xor(
84 &mut self,
85 x: &BinaryBundle<Self::Item>,
86 y: &BinaryBundle<Self::Item>,
87 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
88 x.wires()
89 .iter()
90 .zip(y.wires().iter())
91 .map(|(x, y)| self.xor(x, y))
92 .collect::<Result<Vec<Self::Item>, Self::Error>>()
93 .map(BinaryBundle::new)
94 }
95
96 fn bin_and(
98 &mut self,
99 x: &BinaryBundle<Self::Item>,
100 y: &BinaryBundle<Self::Item>,
101 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
102 x.wires()
103 .iter()
104 .zip(y.wires().iter())
105 .map(|(x, y)| self.and(x, y))
106 .collect::<Result<Vec<Self::Item>, Self::Error>>()
107 .map(BinaryBundle::new)
108 }
109
110 fn bin_or(
112 &mut self,
113 x: &BinaryBundle<Self::Item>,
114 y: &BinaryBundle<Self::Item>,
115 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
116 x.wires()
117 .iter()
118 .zip(y.wires().iter())
119 .map(|(x, y)| self.or(x, y))
120 .collect::<Result<Vec<Self::Item>, Self::Error>>()
121 .map(BinaryBundle::new)
122 }
123
124 fn bin_addition(
126 &mut self,
127 xs: &BinaryBundle<Self::Item>,
128 ys: &BinaryBundle<Self::Item>,
129 ) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
130 if xs.moduli() != ys.moduli() {
131 return Err(Self::Error::from(FancyError::UnequalModuli));
132 }
133 let xwires = xs.wires();
134 let ywires = ys.wires();
135 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None)?;
136 let mut bs = vec![z];
137 for i in 1..xwires.len() {
138 let res = self.adder(&xwires[i], &ywires[i], Some(&c))?;
139 z = res.0;
140 c = res.1;
141 bs.push(z);
142 }
143 Ok((BinaryBundle::new(bs), c))
144 }
145
146 fn bin_addition_no_carry(
148 &mut self,
149 xs: &BinaryBundle<Self::Item>,
150 ys: &BinaryBundle<Self::Item>,
151 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
152 if xs.moduli() != ys.moduli() {
153 return Err(Self::Error::from(FancyError::UnequalModuli));
154 }
155 let xwires = xs.wires();
156 let ywires = ys.wires();
157 let (mut z, mut c) = self.adder(&xwires[0], &ywires[0], None)?;
158 let mut bs = vec![z];
159 for i in 1..xwires.len() - 1 {
160 let res = self.adder(&xwires[i], &ywires[i], Some(&c))?;
161 z = res.0;
162 c = res.1;
163 bs.push(z);
164 }
165 z = self.xor_many(&[
167 xwires.last().unwrap().clone(),
168 ywires.last().unwrap().clone(),
169 c,
170 ])?;
171 bs.push(z);
172 Ok(BinaryBundle::new(bs))
173 }
174
175 fn bin_multiplication_lower_half(
180 &mut self,
181 xs: &BinaryBundle<Self::Item>,
182 ys: &BinaryBundle<Self::Item>,
183 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
184 if xs.moduli() != ys.moduli() {
185 return Err(Self::Error::from(FancyError::UnequalModuli));
186 }
187
188 let xwires = xs.wires();
189 let ywires = ys.wires();
190
191 let mut sum = xwires
192 .iter()
193 .map(|x| self.and(x, &ywires[0]))
194 .collect::<Result<Vec<Self::Item>, Self::Error>>()
195 .map(BinaryBundle::new)?;
196
197 for i in 1..xwires.len() {
198 let mul = xwires
199 .iter()
200 .map(|x| self.and(x, &ywires[i]))
201 .collect::<Result<Vec<Self::Item>, Self::Error>>()
202 .map(BinaryBundle::new)?;
203 let shifted = self.shift(&mul, i).map(BinaryBundle)?;
204 sum = self.bin_addition_no_carry(&sum, &shifted)?;
205 }
206
207 Ok(sum)
208 }
209
210 fn bin_mul(
212 &mut self,
213 xs: &BinaryBundle<Self::Item>,
214 ys: &BinaryBundle<Self::Item>,
215 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
216 if xs.moduli() != ys.moduli() {
217 return Err(Self::Error::from(FancyError::UnequalModuli));
218 }
219
220 let xwires = xs.wires();
221 let ywires = ys.wires();
222
223 let mut sum = xwires
224 .iter()
225 .map(|x| self.and(x, &ywires[0]))
226 .collect::<Result<_, _>>()
227 .map(BinaryBundle::new)?;
228
229 let zero = self.constant(0, 2)?;
230 sum.pad(&zero, 1);
231
232 for i in 1..xwires.len() {
233 let mul = xwires
234 .iter()
235 .map(|x| self.and(x, &ywires[i]))
236 .collect::<Result<_, _>>()
237 .map(BinaryBundle::new)?;
238 let shifted = self.shift_extend(&mul, i).map(BinaryBundle::from)?;
239 let res = self.bin_addition(&sum, &shifted)?;
240 sum = res.0;
241 sum.push(res.1);
242 }
243
244 Ok(sum)
245 }
246
247 fn bin_div(
249 &mut self,
250 xs: &BinaryBundle<Self::Item>,
251 ys: &BinaryBundle<Self::Item>,
252 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
253 if xs.moduli() != ys.moduli() {
254 return Err(Self::Error::from(FancyError::UnequalModuli));
255 }
256 let ys_neg = self.bin_twos_complement(ys)?;
257 let mut acc = self.bin_constant_bundle(0, xs.size())?;
258 let mut qs = BinaryBundle::new(Vec::new());
259 for x in xs.iter().rev() {
260 acc.pop();
261 acc.insert(0, x.clone());
262 let (res, cout) = self.bin_addition(&acc, &ys_neg)?;
263 acc = self.bin_multiplex(&cout, &acc, &res)?;
264 qs.push(cout);
265 }
266 qs.reverse(); Ok(qs)
268 }
269
270 fn bin_twos_complement(
272 &mut self,
273 xs: &BinaryBundle<Self::Item>,
274 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
275 let not_xs = xs
276 .wires()
277 .iter()
278 .map(|x| self.negate(x))
279 .collect::<Result<Vec<Self::Item>, Self::Error>>()
280 .map(BinaryBundle::new)?;
281 let one = self.bin_constant_bundle(1, xs.size())?;
282 self.bin_addition_no_carry(¬_xs, &one)
283 }
284
285 fn bin_subtraction(
289 &mut self,
290 xs: &BinaryBundle<Self::Item>,
291 ys: &BinaryBundle<Self::Item>,
292 ) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
293 let neg_ys = self.bin_twos_complement(ys)?;
294 self.bin_addition(xs, &neg_ys)
295 }
296
297 fn bin_multiplex_constant_bits(
299 &mut self,
300 x: &Self::Item,
301 c1: u128,
302 c2: u128,
303 nbits: usize,
304 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
305 let c1_bs = util::u128_to_bits(c1, nbits)
306 .into_iter()
307 .map(|x: u16| x > 0)
308 .collect_vec();
309 let c2_bs = util::u128_to_bits(c2, nbits)
310 .into_iter()
311 .map(|x: u16| x > 0)
312 .collect_vec();
313 c1_bs
314 .into_iter()
315 .zip(c2_bs.into_iter())
316 .map(|(b1, b2)| self.mux_constant_bits(x, b1, b2))
317 .collect::<Result<Vec<Self::Item>, Self::Error>>()
318 .map(BinaryBundle::new)
319 }
320
321 fn bin_multiplex(
323 &mut self,
324 b: &Self::Item,
325 x: &BinaryBundle<Self::Item>,
326 y: &BinaryBundle<Self::Item>,
327 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
328 x.wires()
329 .iter()
330 .zip(y.wires().iter())
331 .map(|(xwire, ywire)| self.mux(b, xwire, ywire))
332 .collect::<Result<Vec<Self::Item>, Self::Error>>()
333 .map(BinaryBundle::new)
334 }
335
336 fn bin_cmul(
338 &mut self,
339 x: &BinaryBundle<Self::Item>,
340 c: u128,
341 nbits: usize,
342 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
343 let zero = self.bin_constant_bundle(0, nbits)?;
344 util::u128_to_bits(c, nbits)
345 .into_iter()
346 .enumerate()
347 .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
348 .fold(Ok(zero), |z, shift_amt| {
349 let s = self.shift(x, shift_amt).map(BinaryBundle)?;
350 self.bin_addition_no_carry(&(z?), &s)
351 })
352 }
353
354 fn bin_abs(
356 &mut self,
357 x: &BinaryBundle<Self::Item>,
358 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
359 let sign = x.wires().last().unwrap();
360 let negated = self.bin_twos_complement(x)?;
361 self.bin_multiplex(sign, x, &negated)
362 }
363
364 fn bin_lt_signed(
366 &mut self,
367 x: &BinaryBundle<Self::Item>,
368 y: &BinaryBundle<Self::Item>,
369 ) -> Result<Self::Item, Self::Error> {
370 let x_neg = &x.wires().last().unwrap();
372 let y_neg = &y.wires().last().unwrap();
373 let x_pos = self.negate(x_neg)?;
374 let y_pos = self.negate(y_neg)?;
375
376 let x_lt_y_unsigned = self.bin_lt(x, y)?;
379
380 let tru = self.constant(1, 2)?;
382 let x_neg_y_pos = self.and(x_neg, &y_pos)?;
383 let r2 = self.mux(&x_neg_y_pos, &x_lt_y_unsigned, &tru)?;
384
385 let fls = self.constant(0, 2)?;
387 let x_pos_y_neg = self.and(&x_pos, y_neg)?;
388 self.mux(&x_pos_y_neg, &r2, &fls)
389 }
390
391 fn bin_lt(
393 &mut self,
394 x: &BinaryBundle<Self::Item>,
395 y: &BinaryBundle<Self::Item>,
396 ) -> Result<Self::Item, Self::Error> {
397 let (_, lhs) = self.bin_subtraction(x, y)?;
400
401 let y_contains_1 = self.or_many(y.wires())?;
405 let y_eq_0 = self.negate(&y_contains_1)?;
406
407 let x_contains_1 = self.or_many(x.wires())?;
409
410 let rhs = self.and(&y_eq_0, &x_contains_1)?;
412
413 let geq = self.or(&lhs, &rhs)?;
418 let ngeq = self.negate(&geq)?;
419
420 let xy_neq_0 = self.or(&y_contains_1, &x_contains_1)?;
421 self.and(&xy_neq_0, &ngeq)
422 }
423
424 fn bin_geq(
426 &mut self,
427 x: &BinaryBundle<Self::Item>,
428 y: &BinaryBundle<Self::Item>,
429 ) -> Result<Self::Item, Self::Error> {
430 let z = self.bin_lt(x, y)?;
431 self.negate(&z)
432 }
433
434 fn bin_max(
436 &mut self,
437 xs: &[BinaryBundle<Self::Item>],
438 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
439 if xs.is_empty() {
440 return Err(Self::Error::from(FancyError::InvalidArgNum {
441 got: xs.len(),
442 needed: 1,
443 }));
444 }
445 xs.iter().skip(1).fold(Ok(xs[0].clone()), |x, y| {
446 x.map(|x| {
447 let pos = self.bin_lt(&x, y)?;
448 let neg = self.negate(&pos)?;
449 x.wires()
450 .iter()
451 .zip(y.wires().iter())
452 .map(|(x, y)| {
453 let xp = self.and(x, &neg)?;
454 let yp = self.and(y, &pos)?;
455 self.xor(&xp, &yp)
456 })
457 .collect::<Result<Vec<Self::Item>, Self::Error>>()
458 .map(BinaryBundle::new)
459 })?
460 })
461 }
462
463 fn bin_demux(&mut self, x: &BinaryBundle<Self::Item>) -> Result<Vec<Self::Item>, Self::Error> {
465 let wires = x.wires();
466 let nbits = wires.len();
467 if nbits > 8 {
468 return Err(Self::Error::from(FancyError::InvalidArg(
469 "wire bitlength too large".to_string(),
470 )));
471 }
472
473 let mut outs = Vec::with_capacity(1 << nbits);
474
475 for ix in 0..1 << nbits {
476 let mut acc = wires[0].clone();
477 if (ix & 1) == 0 {
478 acc = self.negate(&acc)?;
479 }
480 for (i, w) in wires.iter().enumerate().skip(1) {
481 if ((ix >> i) & 1) > 0 {
482 acc = self.and(&acc, w)?;
483 } else {
484 let not_w = self.negate(w)?;
485 acc = self.and(&acc, ¬_w)?;
486 }
487 }
488 outs.push(acc);
489 }
490
491 Ok(outs)
492 }
493
494 fn bin_rsa(
496 &mut self,
497 x: &BinaryBundle<Self::Item>,
498 c: usize,
499 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
500 self.bin_shr(x, c, x.wires().last().unwrap())
501 }
502
503 fn bin_rsl(
505 &mut self,
506 x: &BinaryBundle<Self::Item>,
507 c: usize,
508 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
509 let zero = self.constant(0, 2)?;
510 self.bin_shr(x, c, &zero)
511 }
512
513 fn bin_shr(
515 &mut self,
516 x: &BinaryBundle<Self::Item>,
517 c: usize,
518 pad: &Self::Item,
519 ) -> Result<BinaryBundle<Self::Item>, Self::Error> {
520 let mut wires: Vec<Self::Item> = Vec::with_capacity(x.wires().len());
521
522 for i in 0..x.wires().len() {
523 let src_idx = i + c;
524 if src_idx >= x.wires().len() {
525 wires.push(pad.clone())
526 } else {
527 wires.push(x.wires()[src_idx].clone())
528 }
529 }
530
531 Ok(BinaryBundle::new(wires))
532 }
533 fn bin_eq_bundles(
535 &mut self,
536 x: &BinaryBundle<Self::Item>,
537 y: &BinaryBundle<Self::Item>,
538 ) -> Result<Self::Item, Self::Error> {
539 let zs = x
541 .wires()
542 .iter()
543 .zip_eq(y.wires().iter())
544 .map(|(x, y)| {
545 let xy = self.xor(x, y)?;
546 self.negate(&xy)
547 })
548 .collect::<Result<Vec<Self::Item>, Self::Error>>()?;
549 self.and_many(&zs)
552 }
553}