use crate::error::FancyError;
use crate::fancy::{Fancy, HasModulus};
use itertools::Itertools;
use std::ops::Index;
/// An ordered list of wires that the bundle gadgets treat as a single unit.
///
/// Every wire reports its own modulus through `HasModulus`, so one bundle may mix moduli.
#[derive(Clone)]
pub struct Bundle<W>(Vec<W>)
where
    W: Clone + HasModulus;
impl<W: Clone + HasModulus> Bundle<W> {
    /// Wraps an existing list of wires in a bundle.
    pub fn new(ws: Vec<W>) -> Bundle<W> {
        Bundle(ws)
    }

    /// Returns the modulus of every wire, in wire order.
    pub fn moduli(&self) -> Vec<u16> {
        self.0.iter().map(HasModulus::modulus).collect()
    }

    /// Borrows the underlying wires.
    pub fn wires(&self) -> &[W] {
        &self.0
    }

    /// Returns the number of wires in the bundle.
    pub fn size(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when every wire has modulus 2 (vacuously `true` for an empty bundle).
    pub fn is_binary(&self) -> bool {
        // Inspect the wires directly instead of materialising `moduli()`, which allocates.
        self.0.iter().all(|w| w.modulus() == 2)
    }

    /// Builds a new bundle that contains, for each `p` in `moduli` (in order), a clone of the
    /// first wire of `self` whose modulus is `p`.
    ///
    /// If `moduli` repeats a value, the same wire is selected more than once.
    ///
    /// # Panics
    ///
    /// Panics if some `p` in `moduli` matches no wire in the bundle.
    pub fn with_moduli(&self, moduli: &[u16]) -> Bundle<W> {
        let ws = moduli
            .iter()
            .map(|&p| {
                self.0
                    .iter()
                    .find(|w| w.modulus() == p)
                    .cloned()
                    .unwrap_or_else(|| panic!("Bundle::with_moduli: no {} modulus in bundle", p))
            })
            .collect();
        Bundle(ws)
    }

    /// Appends `n` clones of `val` to the end of the bundle.
    pub fn pad(&mut self, val: &W, n: usize) {
        self.0.extend(std::iter::repeat(val).take(n).cloned());
    }

    /// Removes and returns the wire at `wire_index`, shifting every later wire down by one.
    ///
    /// # Panics
    ///
    /// Panics if `wire_index` is out of bounds.
    pub fn extract(&mut self, wire_index: usize) -> W {
        self.0.remove(wire_index)
    }

    /// Iterates over the wires by reference, in order.
    pub fn iter(&self) -> std::slice::Iter<'_, W> {
        self.0.iter()
    }
}
/// Positional access to the wires of a bundle, so `bundle[i]` is the `i`-th wire.
impl<W> Index<usize> for Bundle<W>
where
    W: Clone + HasModulus,
{
    type Output = W;

    /// Borrows the wire at position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds.
    fn index(&self, idx: usize) -> &W {
        &self.0[idx]
    }
}
/// Every `Fancy` implementor automatically gets the provided `BundleGadgets` methods.
impl<F> BundleGadgets for F where F: Fancy {}
pub trait BundleGadgets: Fancy {
fn constant_bundle(
&mut self,
xs: &[u16],
ps: &[u16],
) -> Result<Bundle<Self::Item>, Self::Error> {
xs.iter()
.zip(ps.iter())
.map(|(&x, &p)| self.constant(x, p))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle)
}
fn output_bundle(&mut self, x: &Bundle<Self::Item>) -> Result<(), Self::Error> {
for w in x.wires() {
self.output(w)?;
}
Ok(())
}
fn output_bundles(&mut self, xs: &[Bundle<Self::Item>]) -> Result<(), Self::Error> {
for x in xs.iter() {
self.output_bundle(x)?;
}
Ok(())
}
fn add_bundles(
&mut self,
x: &Bundle<Self::Item>,
y: &Bundle<Self::Item>,
) -> Result<Bundle<Self::Item>, Self::Error> {
if x.wires().len() != y.wires().len() {
return Err(Self::Error::from(FancyError::InvalidArgNum {
got: y.wires().len(),
needed: x.wires().len(),
}));
}
x.wires()
.iter()
.zip(y.wires().iter())
.map(|(x, y)| self.add(x, y))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle::new)
}
fn sub_bundles(
&mut self,
x: &Bundle<Self::Item>,
y: &Bundle<Self::Item>,
) -> Result<Bundle<Self::Item>, Self::Error> {
if x.wires().len() != y.wires().len() {
return Err(Self::Error::from(FancyError::InvalidArgNum {
got: y.wires().len(),
needed: x.wires().len(),
}));
}
x.wires()
.iter()
.zip(y.wires().iter())
.map(|(x, y)| self.sub(x, y))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle::new)
}
fn mul_bundles(
&mut self,
x: &Bundle<Self::Item>,
y: &Bundle<Self::Item>,
) -> Result<Bundle<Self::Item>, Self::Error> {
x.wires()
.iter()
.zip(y.wires().iter())
.map(|(x, y)| self.mul(x, y))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle::new)
}
fn mixed_radix_addition(
&mut self,
xs: &[Bundle<Self::Item>],
) -> Result<Bundle<Self::Item>, Self::Error> {
let nargs = xs.len();
let n = xs[0].wires().len();
if nargs < 2 {
return Err(Self::Error::from(FancyError::InvalidArgNum {
got: nargs,
needed: 2,
}));
}
if !xs.iter().all(|x| x.moduli() == xs[0].moduli()) {
return Err(Self::Error::from(FancyError::UnequalModuli));
}
let mut digit_carry = None;
let mut carry_carry = None;
let mut max_carry = 0;
let mut res = Vec::with_capacity(n);
for i in 0..n {
let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
let digit_sum = self.add_many(&ds)?;
let digit = digit_carry.map_or(Ok(digit_sum.clone()), |d| self.add(&digit_sum, &d))?;
if i < n - 1 {
let q = xs[0].wires()[i].modulus();
let max_val = nargs as u16 * (q - 1) + max_carry;
max_carry = max_val / q;
let modded_ds = ds
.iter()
.map(|d| self.mod_change(d, max_val + 1))
.collect::<Result<Vec<Self::Item>, Self::Error>>()?;
let carry_sum = self.add_many(&modded_ds)?;
let carry =
carry_carry.map_or(Ok(carry_sum.clone()), |c| self.add(&carry_sum, &c))?;
let next_mod = xs[0].wires()[i + 1].modulus();
let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
digit_carry = Some(self.proj(&carry, next_mod, Some(tt))?);
let next_max_val = nargs as u16 * (next_mod - 1) + max_carry;
if i < n - 2 {
if max_carry < next_mod {
carry_carry =
Some(self.mod_change(digit_carry.as_ref().unwrap(), next_max_val + 1)?);
} else {
let tt = (0..=max_val).map(|i| i / q).collect_vec();
carry_carry = Some(self.proj(&carry, next_max_val + 1, Some(tt))?);
}
} else {
carry_carry = None;
}
} else {
digit_carry = None;
carry_carry = None;
}
res.push(digit);
}
Ok(Bundle(res))
}
fn mixed_radix_addition_msb_only(
&mut self,
xs: &[Bundle<Self::Item>],
) -> Result<Self::Item, Self::Error> {
let nargs = xs.len();
let n = xs[0].wires().len();
if nargs < 2 {
return Err(Self::Error::from(FancyError::InvalidArgNum {
got: nargs,
needed: 2,
}));
}
if !xs.iter().all(|x| x.moduli() == xs[0].moduli()) {
return Err(Self::Error::from(FancyError::UnequalModuli));
}
let mut opt_carry = None;
let mut max_carry = 0;
for i in 0..n - 1 {
let ds = xs.iter().map(|x| x.wires()[i].clone()).collect_vec();
let q = xs[0].moduli()[i];
let max_val = nargs as u16 * (q - 1) + max_carry;
max_carry = max_val / q;
let modded_ds = ds
.iter()
.map(|d| self.mod_change(d, max_val + 1))
.collect::<Result<Vec<Self::Item>, Self::Error>>()?;
let sum = self.add_many(&modded_ds)?;
let sum_with_carry = opt_carry
.as_ref()
.map_or(Ok(sum.clone()), |c| self.add(&sum, &c))?;
let next_mod = if i < n - 2 {
nargs as u16 * (xs[0].moduli()[i + 1] - 1) + max_carry + 1
} else {
xs[0].moduli()[i + 1]
};
let tt = (0..=max_val).map(|i| (i / q) % next_mod).collect_vec();
opt_carry = Some(self.proj(&sum_with_carry, next_mod, Some(tt))?);
}
let ds = xs.iter().map(|x| x.wires()[n - 1].clone()).collect_vec();
let digit_sum = self.add_many(&ds)?;
opt_carry
.as_ref()
.map_or(Ok(digit_sum.clone()), |d| self.add(&digit_sum, &d))
}
fn multiplex(
&mut self,
b: &Self::Item,
x: &Bundle<Self::Item>,
y: &Bundle<Self::Item>,
) -> Result<Bundle<Self::Item>, Self::Error> {
x.wires()
.iter()
.zip(y.wires().iter())
.map(|(xwire, ywire)| self.mux(b, xwire, ywire))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle)
}
fn mask(
&mut self,
b: &Self::Item,
x: &Bundle<Self::Item>,
) -> Result<Bundle<Self::Item>, Self::Error> {
x.wires()
.iter()
.map(|xwire| self.mul(xwire, b))
.collect::<Result<Vec<Self::Item>, Self::Error>>()
.map(Bundle)
}
fn shift(
&mut self,
x: &Bundle<Self::Item>,
n: usize,
) -> Result<Bundle<Self::Item>, Self::Error> {
let mut ws = x.wires().to_vec();
let zero = self.constant(0, ws.last().unwrap().modulus())?;
for _ in 0..n {
ws.pop();
ws.insert(0, zero.clone());
}
Ok(Bundle(ws))
}
fn eq_bundles(
&mut self,
x: &Bundle<Self::Item>,
y: &Bundle<Self::Item>,
) -> Result<Self::Item, Self::Error> {
if x.moduli() != y.moduli() {
return Err(Self::Error::from(FancyError::UnequalModuli));
}
let wlen = x.wires().len() as u16;
let zs = x
.wires()
.iter()
.zip_eq(y.wires().iter())
.map(|(x, y)| {
let z = self.sub(x, y)?;
let mut eq_zero_tab = vec![0; x.modulus() as usize];
eq_zero_tab[0] = 1;
self.proj(&z, wlen + 1, Some(eq_zero_tab))
})
.collect::<Result<Vec<Self::Item>, Self::Error>>()?;
let z = self.add_many(&zs)?;
let b = zs.len();
let mut tab = vec![0; b + 1];
tab[b] = 1;
self.proj(&z, 2, Some(tab))
}
}