// fancy_garbling/circuits/binary/mux.rs

1use crate::{FancyBinary, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// A two-way multiplexer circuit over binary wires.
///
/// For input `(b, x, y)` return `x` if `b == 0`, otherwise return `y`.
///
/// The type carries no data; the lifetime `'a` only ties the borrowed
/// input wires accepted by its `Circuit` implementation to the circuit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mux<'a>(PhantomData<&'a ()>);

impl<'a> Mux<'a> {
    /// Create a new [`Mux`] circuit.
    pub fn new() -> Self {
        Self::default()
    }
}
16
17impl<'a, F: FancyBinary> Circuit<F> for Mux<'a>
18where
19    F::Item: 'a,
20{
21    type Input = (&'a F::Item, &'a F::Item, &'a F::Item);
22    type Output = F::Item;
23
24    fn execute(
25        &self,
26        backend: &mut F,
27        inputs: Self::Input,
28        channel: &mut Channel,
29    ) -> Result<Self::Output> {
30        // The mux can be computed as `b (x ^ y) ^ x`.
31        let (b, x, y) = inputs;
32        let xor = backend.xor(x, y);
33        let and = backend.and(b, &xor, channel)?;
34        Ok(backend.xor(&and, x))
35    }
36}
37
/// A two-way multiplexer whose two candidate outputs are public constants.
///
/// For input `(b, c1, c2)`, return `c1` if `b == 0`, otherwise return `c2`.
///
/// The type carries no data; the lifetime `'a` only ties the borrowed
/// selector wire accepted by its `Circuit` implementation to the circuit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct MuxConstants<'a>(PhantomData<&'a ()>);

impl<'a> MuxConstants<'a> {
    /// Create a new [`MuxConstants`] circuit.
    pub fn new() -> Self {
        Self::default()
    }
}
48
49impl<'a, F: FancyBinary> Circuit<F> for MuxConstants<'a>
50where
51    F::Item: 'a,
52{
53    type Input = (&'a F::Item, bool, bool);
54    type Output = F::Item;
55
56    fn execute(
57        &self,
58        backend: &mut F,
59        inputs: Self::Input,
60        channel: &mut Channel,
61    ) -> Result<Self::Output> {
62        let (b, c1, c2) = inputs;
63        match (c1, c2) {
64            (false, true) => Ok(b.clone()),
65            (true, false) => Ok(backend.negate(b)),
66            (false, false) => backend.constant(0, 2, channel),
67            (true, true) => backend.constant(1, 2, channel),
68        }
69    }
70}
71
#[cfg(test)]
mod test {
    use super::{Mux, MuxConstants};
    use crate::dummy::{Dummy, DummyVal};

    /// Exhaustively checks the wire-input mux on all eight `(b, x, y)` triples.
    #[test]
    fn mux() {
        // Bits of the counter encode (b, x, y), most significant first, which
        // visits the triples in the same order as three nested 0..=1 loops.
        for bits in 0..8 {
            let (b, x, y) = ((bits >> 2) & 1, (bits >> 1) & 1, bits & 1);
            let expected = if b == 0 { x } else { y };
            let sel = DummyVal::new(b, 2);
            let lhs = DummyVal::new(x, 2);
            let rhs = DummyVal::new(y, 2);
            let out = Dummy::eval(&Mux::new(), (&sel, &lhs, &rhs)).unwrap();
            assert_eq!(out.val(), expected);
        }
    }

    /// Exhaustively checks the constant-input mux on all eight `(b, c1, c2)` triples.
    #[test]
    fn mux_constants() {
        for bits in 0..8 {
            let (b, x, y) = ((bits >> 2) & 1, (bits >> 1) & 1, bits & 1);
            let expected = if b == 0 { x } else { y };
            let sel = DummyVal::new(b, 2);
            let out = Dummy::eval(&MuxConstants::new(), (&sel, x != 0, y != 0)).unwrap();
            assert_eq!(out.val(), expected);
        }
    }
}