Skip to main content

fancy_garbling/circuits/binary/
binary_multiplex.rs

1use std::marker::PhantomData;
2
3use crate::{
4    BinaryBundle, FancyBinary,
5    circuit::Circuit,
6    circuits::binary::{Mux, MuxConstants},
7    util::u128_to_bits,
8};
9use swanky_channel::Channel;
10use swanky_error::Result;
11
/// For bit `b` and [`BinaryBundle`]s `x` and `y`, output `x` if `b == 0`, and
/// `y` otherwise.
///
/// This is a zero-sized marker; the lifetime `'a` ties the circuit to the
/// borrowed bundles it receives as input.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryMultiplex<'a>(PhantomData<&'a ()>);

impl<'a> BinaryMultiplex<'a> {
    /// Create a new [`BinaryMultiplex`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
23
24impl<'a, F: FancyBinary> Circuit<F> for BinaryMultiplex<'a>
25where
26    F::Item: 'a,
27{
28    type Input = (
29        F::Item,
30        &'a BinaryBundle<F::Item>,
31        &'a BinaryBundle<F::Item>,
32    );
33    type Output = BinaryBundle<F::Item>;
34
35    fn execute(
36        &self,
37        backend: &mut F,
38        inputs: Self::Input,
39        channel: &mut Channel,
40    ) -> Result<Self::Output> {
41        let (b, xs, ys) = inputs;
42        xs.wires()
43            .iter()
44            .zip(ys.wires().iter())
45            .map(|(x, y)| Mux::new().execute(backend, (&b.clone(), x, y), channel))
46            .collect::<Result<Vec<_>>>()
47            .map(BinaryBundle::new)
48    }
49}
50
/// For bit `b` and constants `c1` and `c2` of bitlength `n`, output `c1` if
/// `b == 0` and `c2` otherwise.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryMultiplexConstantBits;

impl BinaryMultiplexConstantBits {
    /// Create a new [`BinaryMultiplexConstantBits`] circuit.
    pub fn new() -> Self {
        Self
    }
}
54
55impl<F: FancyBinary> Circuit<F> for BinaryMultiplexConstantBits {
56    type Input = (F::Item, u128, u128, usize);
57    type Output = BinaryBundle<F::Item>;
58
59    fn execute(
60        &self,
61        backend: &mut F,
62        inputs: Self::Input,
63        channel: &mut Channel,
64    ) -> Result<Self::Output> {
65        let (b, c1, c2, nbits) = inputs;
66
67        let c1_bs = u128_to_bits(c1, nbits)
68            .into_iter()
69            .map(|x: u16| x > 0)
70            .collect::<Vec<_>>();
71        let c2_bs = u128_to_bits(c2, nbits)
72            .into_iter()
73            .map(|x: u16| x > 0)
74            .collect::<Vec<_>>();
75        c1_bs
76            .into_iter()
77            .zip(c2_bs)
78            .map(|(b1, b2)| MuxConstants::new().execute(backend, (&b.clone(), b1, b2), channel))
79            .collect::<Result<_>>()
80            .map(BinaryBundle::new)
81    }
82}
83
#[cfg(test)]
mod test {
    use super::BinaryMultiplex;
    use crate::{
        circuits::binary::BinaryMultiplexConstantBits,
        dummy::{Dummy, DummyVal},
    };
    use rand::Rng;

    /// Draw a random bit width in `1..=200`. Widths wider than a `u128` are
    /// allowed on purpose so zero-padded high bits are exercised as well.
    fn random_nbits(rng: &mut impl Rng) -> usize {
        1 + (rng.r#gen::<usize>() % 200)
    }

    /// Draw a uniformly random value that fits in `nbits` bits (any `u128`
    /// when `nbits >= 128`).
    ///
    /// Reducing modulo `nbits` instead would only ever produce values below
    /// `nbits` (i.e. below 200), leaving the upper bits of every bundle at
    /// zero and untested.
    fn random_value(rng: &mut impl Rng, nbits: usize) -> u128 {
        let mask = if nbits >= 128 {
            u128::MAX
        } else {
            (1u128 << nbits) - 1
        };
        rng.r#gen::<u128>() & mask
    }

    #[test]
    fn binary_multiplex() {
        let mut rng = rand::thread_rng();
        let nbits = random_nbits(&mut rng);
        let x = random_value(&mut rng, nbits);
        let y = random_value(&mut rng, nbits);
        let x_inputs = DummyVal::to_binary(x, nbits);
        let y_inputs = DummyVal::to_binary(y, nbits);

        for b in 0..=1 {
            let output = Dummy::eval(
                &BinaryMultiplex::new(),
                (DummyVal::new(b, 2), &x_inputs, &y_inputs),
            )
            .unwrap();
            assert_eq!(DummyVal::from_binary(&output), if b == 0 { x } else { y });
        }
    }

    #[test]
    fn binary_multiplex_constant_bits() {
        let mut rng = rand::thread_rng();
        let nbits = random_nbits(&mut rng);
        let x = random_value(&mut rng, nbits);
        let y = random_value(&mut rng, nbits);

        for b in 0..=1 {
            let output = Dummy::eval(
                &BinaryMultiplexConstantBits,
                (DummyVal::new(b, 2), x, y, nbits),
            )
            .unwrap();
            assert_eq!(DummyVal::from_binary(&output), if b == 0 { x } else { y });
        }
    }
}