Skip to main content

fancy_garbling/circuits/binary/
binary_max.rs

1use crate::{BinaryBundle, FancyBinary, circuit::Circuit, circuits::binary::BinaryLessThan};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Binary max.
///
/// A [`Circuit`] that, given a slice of [`BinaryBundle`]s, returns the bundle
/// holding the maximum value, as ordered by [`BinaryLessThan`].
///
/// # Panics
/// Executing the circuit panics if the input slice is empty.
#[derive(Debug, Default, Clone, Copy)]
pub struct BinaryMax<'a>(PhantomData<&'a ()>);
14
15impl<'a> BinaryMax<'a> {
16    /// Create a new [`BinaryMax`] circuit.
17    pub fn new() -> Self {
18        Default::default()
19    }
20}
21
22impl<'a, F: FancyBinary> Circuit<F> for BinaryMax<'a>
23where
24    F::Item: 'a,
25{
26    type Input = &'a [BinaryBundle<F::Item>];
27    type Output = BinaryBundle<F::Item>;
28
29    fn execute(
30        &self,
31        backend: &mut F,
32        inputs: Self::Input,
33        channel: &mut Channel,
34    ) -> Result<Self::Output> {
35        let xs = inputs;
36        assert!(!xs.is_empty(), "`xs` cannot be empty");
37        xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
38            // Compute `x < y`.
39            let pos = BinaryLessThan::new().execute(backend, (&x, y), channel)?;
40            // Compute `!(x < y)`.
41            let neg = backend.negate(&pos);
42            // Compute `x * (x >= y) ^ y * (x < y)`.
43            Ok(BinaryBundle::new(
44                x.wires()
45                    .iter()
46                    .zip(y.wires().iter())
47                    .map(|(x, y)| {
48                        let xp = backend.and(x, &neg, channel)?;
49                        let yp = backend.and(y, &pos, channel)?;
50                        Ok(backend.xor(&xp, &yp))
51                    })
52                    .collect::<Result<Vec<F::Item>>>()?,
53            ))
54        })
55    }
56}
57
#[cfg(test)]
mod test {
    use rand::Rng;

    use super::BinaryMax;
    use crate::{
        BinaryBundle,
        dummy::{Dummy, DummyVal},
    };

    /// Encodes every value of `xs` as a 64-bit bundle, runs [`BinaryMax`] on the
    /// `Dummy` backend, and decodes the resulting bundle.
    fn eval_max(xs: &[u128]) -> u128 {
        let nbits = 64;
        let inputs: Vec<BinaryBundle<DummyVal>> =
            xs.iter().map(|x| DummyVal::to_binary(*x, nbits)).collect();
        let output = Dummy::eval(&BinaryMax::new(), inputs.as_slice()).unwrap();
        DummyVal::from_binary(&output)
    }

    #[test]
    fn binary_max() {
        let mut rng = rand::thread_rng();
        // Keep values inside the 64-bit encoding used by `eval_max`.
        let q = 1u128 << 64;
        let nitems = 10;

        for _ in 0..16 {
            let xs: Vec<u128> = (0..nitems).map(|_| rng.r#gen::<u128>() % q).collect();
            let max = *xs.iter().max().unwrap();
            assert_eq!(eval_max(&xs), max);
        }
    }

    #[test]
    fn binary_max_single_item() {
        assert_eq!(eval_max(&[42]), 42);
    }

    #[test]
    fn binary_max_position_and_ties() {
        let top = u128::from(u64::MAX);
        // The maximum first, last, and in the middle exercises both mux branches.
        assert_eq!(eval_max(&[top, 1, 0]), top);
        assert_eq!(eval_max(&[0, 1, top]), top);
        assert_eq!(eval_max(&[3, top, 7]), top);
        // Equal values: the comparison is false, so the running value is kept.
        assert_eq!(eval_max(&[5, 5, 5]), 5);
        assert_eq!(eval_max(&[0, 0]), 0);
    }

    #[test]
    #[should_panic(expected = "`xs` cannot be empty")]
    fn binary_max_empty_panics() {
        let _ = eval_max(&[]);
    }
}