use crate::error::FancyError;
use crate::fancy::bundle::{Bundle, BundleGadgets};
use crate::fancy::{Fancy, HasModulus};
use crate::util;
use itertools::Itertools;
use std::ops::Deref;
/// Newtype wrapper around [`Bundle`] marking a bundle whose wires are meant to be
/// binary (modulus 2).
///
/// The wrapper itself stores no extra data. The modulus-2 expectation is only
/// checked (in debug builds) by the `From<Bundle<W>>` conversion.
#[derive(Clone)]
pub struct BinaryBundle<W>(Bundle<W>)
where
    W: Clone + HasModulus;
impl<W: Clone + HasModulus> BinaryBundle<W> {
pub fn new(ws: Vec<W>) -> BinaryBundle<W> {
BinaryBundle(Bundle::new(ws))
}
pub fn extract(self) -> Bundle<W> {
self.0
}
}
impl<W: Clone + HasModulus> Deref for BinaryBundle<W> {
type Target = Bundle<W>;
fn deref(&self) -> &Bundle<W> {
&self.0
}
}
/// Conversion from a plain [`Bundle`] into a `BinaryBundle`.
impl<W> From<Bundle<W>> for BinaryBundle<W>
where
    W: Clone + HasModulus,
{
    /// Wraps `b` without copying it.
    ///
    /// In debug builds this asserts that every modulus reported by `b.moduli()`
    /// equals 2. In release builds the check is compiled out and any bundle is
    /// accepted.
    fn from(b: Bundle<W>) -> Self {
        debug_assert!(b.moduli().iter().all(|&p| p == 2));
        BinaryBundle(b)
    }
}
/// Blanket implementation: every type implementing `Fancy` receives the default
/// method bodies of `BinaryGadgets`.
impl<F> BinaryGadgets for F where F: Fancy {}
pub trait BinaryGadgets: Fancy + BundleGadgets {
fn bin_constant_bundle(
&mut self,
val: u128,
nbits: usize,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
self.constant_bundle(&util::u128_to_bits(val, nbits), &vec![2; nbits])
.map(BinaryBundle)
}
fn bin_xor(
&mut self,
x: &BinaryBundle<Self::Item>,
y: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
self.add_bundles(&x, &y).map(BinaryBundle)
}
fn bin_and(
&mut self,
x: &BinaryBundle<Self::Item>,
y: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
self.mul_bundles(&x, &y).map(BinaryBundle)
}
fn bin_or(
&mut self,
x: &BinaryBundle<Self::Item>,
y: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
x.wires()
.iter()
.zip(y.wires().iter())
.map(|(x, y)| self.or(x, y))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(BinaryBundle::new)
}
/// Ripple-carry addition of two binary bundles.
///
/// Returns the sum bits (same width as the inputs) together with the carry
/// produced by the final `adder` call.
///
/// # Errors
/// Returns `FancyError::UnequalModuli` if the operands' moduli differ, and
/// propagates any error from `adder`.
///
/// # Panics
/// Panics if the bundles are empty (wire 0 is indexed unconditionally).
fn bin_addition(
    &mut self,
    xs: &BinaryBundle<Self::Item>,
    ys: &BinaryBundle<Self::Item>,
) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
    if xs.moduli() != ys.moduli() {
        return Err(Self::Error::from(FancyError::UnequalModuli));
    }
    let xwires = xs.wires();
    let ywires = ys.wires();

    // Lowest position has no incoming carry.
    let (lsb, mut carry) = self.adder(&xwires[0], &ywires[0], None)?;
    let mut sum_bits = Vec::with_capacity(xwires.len());
    sum_bits.push(lsb);

    // Every following position consumes the previous carry.
    for (x, y) in xwires.iter().zip(ywires.iter()).skip(1) {
        let (bit, carry_out) = self.adder(x, y, Some(&carry))?;
        sum_bits.push(bit);
        carry = carry_out;
    }
    Ok((BinaryBundle::new(sum_bits), carry))
}
/// Ripple-carry addition of two binary bundles, discarding the final carry.
///
/// The result has the same width as the inputs. The top bit is computed with
/// `add_many` (x + y + carry-in) instead of a full adder, since the carry out
/// of that position is never needed.
///
/// # Errors
/// Returns `FancyError::UnequalModuli` if the operands' moduli differ, and
/// propagates any error from `adder`, `add` or `add_many`.
///
/// # Panics
/// Panics if the bundles are empty (wire 0 is indexed unconditionally).
fn bin_addition_no_carry(
    &mut self,
    xs: &BinaryBundle<Self::Item>,
    ys: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
    if xs.moduli() != ys.moduli() {
        return Err(Self::Error::from(FancyError::UnequalModuli));
    }
    let xwires = xs.wires();
    let ywires = ys.wires();
    let n = xwires.len();

    // Single-bit operands: the only bit is also the top bit, so the sum is just
    // x + y with no carry involved. Running the general path here would push
    // wire 0 twice and return a 2-wire bundle.
    if n == 1 {
        let only = self.add(&xwires[0], &ywires[0])?;
        return Ok(BinaryBundle::new(vec![only]));
    }

    // Lowest position has no incoming carry.
    let (lsb, mut carry) = self.adder(&xwires[0], &ywires[0], None)?;
    let top = n - 1;
    let mut sum_bits = Vec::with_capacity(n);
    sum_bits.push(lsb);

    // Middle positions: full adders chaining the carry.
    for (x, y) in xwires[1..top].iter().zip(ywires[1..top].iter()) {
        let (bit, carry_out) = self.adder(x, y, Some(&carry))?;
        sum_bits.push(bit);
        carry = carry_out;
    }

    // Top position: only the sum bit is needed, the carry out is dropped.
    let msb = self.add_many(&[xwires[top].clone(), ywires[top].clone(), carry])?;
    sum_bits.push(msb);
    Ok(BinaryBundle::new(sum_bits))
}
fn bin_multiplication_lower_half(
&mut self,
xs: &BinaryBundle<Self::Item>,
ys: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
if xs.moduli() != ys.moduli() {
return Err(Self::Error::from(FancyError::UnequalModuli));
}
let xwires = xs.wires();
let ywires = ys.wires();
let mut sum = xwires
.iter()
.map(|x| self.and(x, &ywires[0]))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(BinaryBundle::new)?;
for i in 1..xwires.len() {
let mul = xwires
.iter()
.map(|x| self.and(x, &ywires[i]))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(BinaryBundle::new)?;
let shifted = self.shift(&mul, i).map(BinaryBundle)?;
sum = self.bin_addition_no_carry(&sum, &shifted)?;
}
Ok(sum)
}
/// Two's-complement negation of a binary bundle: negate every wire, then add
/// the constant 1 while discarding the final carry.
///
/// # Errors
/// Propagates any error from `negate`, `bin_constant_bundle` or
/// `bin_addition_no_carry`.
fn bin_twos_complement(
    &mut self,
    xs: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
    let not_xs = xs
        .wires()
        .iter()
        .map(|x| self.negate(x))
        .collect::<Result<Vec<Self::Item>, Self::Error>>()
        .map(BinaryBundle::new)?;
    let one = self.bin_constant_bundle(1, xs.size())?;
    // The original text had a mis-encoded `&not_xs` here.
    self.bin_addition_no_carry(&not_xs, &one)
}
/// Subtracts `ys` from `xs` by adding the two's complement of `ys`.
///
/// Returns the result bits together with the carry reported by
/// `bin_addition` for `xs + (-ys)`.
///
/// # Errors
/// Propagates any error from `bin_twos_complement` or `bin_addition`.
fn bin_subtraction(
    &mut self,
    xs: &BinaryBundle<Self::Item>,
    ys: &BinaryBundle<Self::Item>,
) -> Result<(BinaryBundle<Self::Item>, Self::Item), Self::Error> {
    let minus_ys = self.bin_twos_complement(ys)?;
    self.bin_addition(xs, &minus_ys)
}
/// Builds an `nbits`-wide binary bundle by calling `mux_constant_bits(x, b1, b2)`
/// for each bit position, where `b1` / `b2` are the corresponding bits of
/// `c1` / `c2` as produced by `util::u128_to_bits`.
///
/// NOTE(review): presumably this selects between the constants `c1` and `c2`
/// depending on `x`; which value of `x` picks which constant is determined by
/// `mux_constant_bits` — confirm there.
///
/// # Errors
/// Propagates any error from `mux_constant_bits`.
fn bin_multiplex_constant_bits(
    &mut self,
    x: &Self::Item,
    c1: u128,
    c2: u128,
    nbits: usize,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
    // Pair up the bits of both constants as booleans, position by position.
    let bit_pairs = util::u128_to_bits(c1, nbits)
        .into_iter()
        .zip(util::u128_to_bits(c2, nbits).into_iter())
        .map(|(b1, b2): (u16, u16)| (b1 > 0, b2 > 0))
        .collect_vec();

    let mut wires = Vec::with_capacity(bit_pairs.len());
    for (b1, b2) in bit_pairs {
        wires.push(self.mux_constant_bits(x, b1, b2)?);
    }
    Ok(BinaryBundle::new(wires))
}
fn bin_cmul(
&mut self,
x: &BinaryBundle<Self::Item>,
c: u128,
nbits: usize,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
let zero = self.bin_constant_bundle(0, nbits)?;
util::u128_to_bits(c, nbits)
.into_iter()
.enumerate()
.filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
.fold(Ok(zero), |z, shift_amt| {
let s = self.shift(x, shift_amt).map(BinaryBundle)?;
self.bin_addition_no_carry(&(z?), &s)
})
}
fn bin_abs(
&mut self,
x: &BinaryBundle<Self::Item>,
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
let sign = x.wires().last().unwrap();
let negated = self.bin_twos_complement(x)?;
self.multiplex(&sign, x, &negated).map(BinaryBundle)
}
/// Returns a wire that is 1 when `x < y` (unsigned comparison) and 0 otherwise.
///
/// The carry out of `x + (-y)` is 1 exactly when `y != 0 && x >= y`. When
/// `y == 0` its two's complement is 0 again, so that carry is always 0 even
/// though `x >= 0` holds for every `x`. The full "greater or equal" predicate
/// is therefore `carry OR (y == 0)`, and the result is its negation.
///
/// # Errors
/// Propagates any error from `bin_subtraction`, `or_many`, `negate` or `or`.
fn bin_lt(
    &mut self,
    x: &BinaryBundle<Self::Item>,
    y: &BinaryBundle<Self::Item>,
) -> Result<Self::Item, Self::Error> {
    // carry == 1  <=>  y != 0 && x >= y
    let (_, carry) = self.bin_subtraction(x, y)?;

    // y == 0  =>  x >= y for every x, including x == 0.
    let y_contains_1 = self.or_many(y.wires())?;
    let y_eq_0 = self.negate(&y_contains_1)?;

    let geq = self.or(&carry, &y_eq_0)?;
    self.negate(&geq)
}
/// Returns a wire that is the negation of `bin_lt(x, y)`, i.e. 1 when `x >= y`
/// according to `bin_lt`.
///
/// # Errors
/// Propagates any error from `bin_lt` or `negate`.
fn bin_geq(
    &mut self,
    x: &BinaryBundle<Self::Item>,
    y: &BinaryBundle<Self::Item>,
) -> Result<Self::Item, Self::Error> {
    self.bin_lt(x, y).and_then(|less| self.negate(&less))
}
/// Computes the maximum of the given binary bundles, as ordered by `bin_lt`.
///
/// The running maximum starts as `xs[0]`. For each further bundle `y`, every
/// output wire is `best * NOT(best < y) + y * (best < y)`, so `y` replaces the
/// running maximum exactly when `bin_lt` reports `best < y`.
///
/// # Errors
/// Returns `FancyError::InvalidArgNum` when fewer than two bundles are given,
/// and propagates any error from `bin_lt`, `negate`, `mul` or `add`.
fn bin_max(
    &mut self,
    xs: &[BinaryBundle<Self::Item>],
) -> Result<BinaryBundle<Self::Item>, Self::Error> {
    if xs.len() < 2 {
        return Err(Self::Error::from(FancyError::InvalidArgNum {
            got: xs.len(),
            needed: 2,
        }));
    }

    let mut best = xs[0].clone();
    for candidate in xs.iter().skip(1) {
        let take_candidate = self.bin_lt(&best, candidate)?;
        let keep_best = self.negate(&take_candidate)?;

        // Per-wire select between the current maximum and the candidate.
        let mut merged = Vec::with_capacity(best.wires().len());
        for (b, c) in best.wires().iter().zip(candidate.wires().iter()) {
            let kept = self.mul(b, &keep_best)?;
            let taken = self.mul(c, &take_candidate)?;
            merged.push(self.add(&kept, &taken)?);
        }
        best = BinaryBundle::new(merged);
    }
    Ok(best)
}
/// Demultiplexes an `n`-wire binary bundle into `2^n` wires.
///
/// Output `ix` is the AND over all positions `i` of either wire `i` (when bit
/// `i` of `ix` is 1) or its negation (when that bit is 0). It is therefore 1
/// exactly when wire `i` equals bit `i` of `ix` for every `i`.
///
/// # Errors
/// Propagates any error from `negate` or `and`.
///
/// # Panics
/// Panics if `x` has no wires (wire 0 is indexed unconditionally).
fn bin_demux(
    &mut self,
    x: &BinaryBundle<Self::Item>,
) -> Result<Vec<Self::Item>, Self::Error> {
    let wires = x.wires();
    let nbits = wires.len();
    // Explicit usize so the counter does not default to i32.
    let n_outputs = 1usize << nbits;
    let mut outs = Vec::with_capacity(n_outputs);
    for ix in 0..n_outputs {
        // Literal for position 0 seeds the conjunction.
        let mut acc = wires[0].clone();
        if (ix & 1) == 0 {
            acc = self.negate(&acc)?;
        }
        for (i, w) in wires.iter().enumerate().skip(1) {
            if ((ix >> i) & 1) > 0 {
                acc = self.and(&acc, w)?;
            } else {
                let not_w = self.negate(w)?;
                // The original text had a mis-encoded `&not_w` here.
                acc = self.and(&acc, &not_w)?;
            }
        }
        outs.push(acc);
    }
    Ok(outs)
}
}