Skip to main content

fancy_garbling/circuits/binary/
binary_multiplication.rs

1use crate::{
2    BinaryBundle, FancyBinary,
3    circuit::{Circuit, CircuitInputMapper},
4    circuits::binary::{
5        BinaryAddition, BinaryAdditionNoCarry, BinaryConstant, BinaryLeftShift,
6        BinaryLeftShiftExtend,
7    },
8    util::u128_to_bits,
9};
10use core::marker::PhantomData;
11use swanky_channel::Channel;
12use swanky_error::Result;
13
/// For [`BinaryBundle`] inputs `x` and `y`, output `x * y`.
///
/// The output holds the full, untruncated product; see
/// [`BinaryMultiplicationLowerHalf`] for the variant that keeps only the
/// low-order bits.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> BinaryMultiplication<'a> {
    /// Create a new [`BinaryMultiplication`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
24
25impl<'a, F: FancyBinary> Circuit<F> for BinaryMultiplication<'a>
26where
27    F::Item: 'a,
28{
29    type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
30    type Output = BinaryBundle<F::Item>;
31
32    fn execute(
33        &self,
34        backend: &mut F,
35        inputs: Self::Input,
36        channel: &mut Channel,
37    ) -> Result<Self::Output> {
38        let (xs, ys) = inputs;
39        assert_eq!(xs.moduli(), ys.moduli());
40
41        let xwires = xs.wires();
42        let ywires = ys.wires();
43
44        let zero = backend.constant(0, 2, channel)?;
45
46        let mut sum = xwires
47            .iter()
48            .map(|x| backend.and(x, &ywires[0], channel))
49            .collect::<Result<_>>()
50            .map(BinaryBundle::new)?;
51
52        sum.pad(&zero, 1);
53
54        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
55            let mul = xwires
56                .iter()
57                .map(|x| backend.and(x, ywire, channel))
58                .collect::<Result<_>>()
59                .map(BinaryBundle::new)?;
60            let shifted = BinaryLeftShiftExtend::new().execute(backend, (&mul, i), channel)?;
61            let res = BinaryAddition::new().execute(backend, (&sum, &shifted), channel)?;
62            sum = res.0;
63            sum.push(res.1);
64        }
65
66        Ok(sum)
67    }
68}
69
/// For [`BinaryBundle`]s `x` and `y`, return the lower-order half of `x * y`,
/// i.e. the product reduced modulo `2^n` for `n`-bit inputs.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryMultiplicationLowerHalf<'a>(PhantomData<&'a ()>);

impl<'a> BinaryMultiplicationLowerHalf<'a> {
    /// Create a new [`BinaryMultiplicationLowerHalf`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
81
82impl<'a, F: FancyBinary> Circuit<F> for BinaryMultiplicationLowerHalf<'a>
83where
84    F::Item: 'a,
85{
86    type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
87    type Output = BinaryBundle<F::Item>;
88
89    fn execute(
90        &self,
91        backend: &mut F,
92        inputs: Self::Input,
93        channel: &mut Channel,
94    ) -> Result<Self::Output> {
95        let (xs, ys) = inputs;
96        assert_eq!(xs.moduli(), ys.moduli());
97
98        let xwires = xs.wires();
99        let ywires = ys.wires();
100
101        let mut sum = xwires
102            .iter()
103            .map(|x| backend.and(x, &ywires[0], channel))
104            .collect::<Result<_>>()
105            .map(BinaryBundle::new)?;
106
107        for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
108            let mul = xwires
109                .iter()
110                .map(|x| backend.and(x, ywire, channel))
111                .collect::<Result<_>>()
112                .map(BinaryBundle::new)?;
113            let shifted = BinaryLeftShift::new().execute(backend, (&mul, i), channel)?;
114            sum = BinaryAdditionNoCarry::new().execute(backend, (&sum, &shifted), channel)?;
115        }
116        Ok(sum)
117    }
118}
119
/// For [`BinaryBundle`] `x`, constant `c`, and bitlength `n`, output `x * c`,
/// where the output is of bitlength `n`.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryConstantMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> BinaryConstantMultiplication<'a> {
    /// Create a new [`BinaryConstantMultiplication`] circuit.
    pub fn new() -> Self {
        Default::default()
    }
}
131
132impl<'a, F: FancyBinary> Circuit<F> for BinaryConstantMultiplication<'a>
133where
134    F::Item: 'a,
135{
136    type Input = (&'a BinaryBundle<F::Item>, u128, usize);
137    type Output = BinaryBundle<F::Item>;
138
139    fn execute(
140        &self,
141        backend: &mut F,
142        inputs: Self::Input,
143        channel: &mut Channel,
144    ) -> Result<Self::Output> {
145        let (x, c, nbits) = inputs;
146        let zero = BinaryConstant::new(0, nbits).execute(backend, (), channel)?;
147        u128_to_bits(c, nbits)
148            .into_iter()
149            .enumerate()
150            .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
151            .try_fold(zero, |z, shift_amt| {
152                let s = BinaryLeftShift::new().execute(backend, (x, shift_amt), channel)?;
153                BinaryAdditionNoCarry::new().execute(backend, (&z, &s), channel)
154            })
155    }
156}
157
/// Circuit for testing [`BinaryMultiplication`].
///
/// The wrapped `usize` is the bit width of each operand; the input mapper for
/// this type expects `2 * nbits` input wires (`x` followed by `y`).
#[derive(Debug, Clone, Copy)]
pub struct TestBinaryMultiplication<'a>(pub usize, PhantomData<&'a ()>);

impl<'a> TestBinaryMultiplication<'a> {
    /// Create a new [`TestBinaryMultiplication`] circuit for `nbits`-bit operands.
    pub fn new(nbits: usize) -> Self {
        Self(nbits, PhantomData)
    }
}
167
168impl<'a, F: FancyBinary> Circuit<F> for TestBinaryMultiplication<'a>
169where
170    F::Item: 'a,
171{
172    type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
173    type Output = BinaryBundle<F::Item>;
174
175    fn execute(
176        &self,
177        backend: &mut F,
178        inputs: Self::Input,
179        channel: &mut Channel,
180    ) -> Result<Self::Output> {
181        BinaryMultiplication::new().execute(backend, inputs, channel)
182    }
183}
184
185impl<'a, F: FancyBinary> CircuitInputMapper<F> for TestBinaryMultiplication<'a>
186where
187    F::Item: 'a,
188{
189    fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
190        assert_eq!(inputs.len(), self.0 * 2);
191        let (x, y) = inputs.split_at(self.0);
192        let x = BinaryBundle::new(x.to_vec());
193        let y = BinaryBundle::new(y.to_vec());
194        // Leak memory to create static references for the test
195        let x_ref: &'a BinaryBundle<F::Item> = Box::leak(Box::new(x));
196        let y_ref: &'a BinaryBundle<F::Item> = Box::leak(Box::new(y));
197        (x_ref, y_ref)
198    }
199
200    fn ninputs(&self) -> usize {
201        self.0 * 2
202    }
203
204    fn modulus(&self, _: usize) -> u16 {
205        2
206    }
207}
208
#[cfg(test)]
mod test {
    use crate::{
        circuits::binary::{
            BinaryConstantMultiplication, BinaryMultiplication, BinaryMultiplicationLowerHalf,
        },
        dummy::{Dummy, DummyVal},
    };
    use rand::{Rng, thread_rng};

    /// Operand bit widths exercised by every test. Width `1` covers the case
    /// where the shift-and-add loop body never runs; `64` is the widest width
    /// whose full product still fits in a `u128`.
    const WIDTHS: [usize; 4] = [1, 2, 8, 64];

    /// Number of random operand values drawn per width, in addition to the
    /// fixed boundary values.
    const RANDOM_OPERANDS: usize = 3;

    /// Operand values for an `nbits`-wide test: the boundary values `0`, `1`
    /// and `2^nbits - 1` (all ones, which exercises the longest carry chains),
    /// followed by a few uniformly random values below `2^nbits`.
    fn operands(nbits: usize) -> Vec<u128> {
        let q = 1u128 << nbits;
        let mut rng = thread_rng();
        let mut vals = vec![0, 1, q - 1];
        vals.extend((0..RANDOM_OPERANDS).map(|_| rng.r#gen::<u128>() % q));
        vals
    }

    #[test]
    fn binary_multiplication() {
        for nbits in WIDTHS {
            let ops = operands(nbits);
            for &x in &ops {
                for &y in &ops {
                    let x_input = DummyVal::to_binary(x, nbits);
                    let y_input = DummyVal::to_binary(y, nbits);
                    let output =
                        Dummy::eval(&BinaryMultiplication::new(), (&x_input, &y_input)).unwrap();
                    assert_eq!(
                        DummyVal::from_binary(&output),
                        x * y,
                        "{x} * {y} at {nbits} bits"
                    );
                }
            }
        }
    }

    #[test]
    fn binary_multiplication_lower_half() {
        for nbits in WIDTHS {
            let q = 1u128 << nbits;
            let ops = operands(nbits);
            for &x in &ops {
                for &y in &ops {
                    let x_input = DummyVal::to_binary(x, nbits);
                    let y_input = DummyVal::to_binary(y, nbits);
                    let output =
                        Dummy::eval(&BinaryMultiplicationLowerHalf::new(), (&x_input, &y_input))
                            .unwrap();
                    assert_eq!(
                        DummyVal::from_binary(&output),
                        (x * y) % q,
                        "{x} * {y} mod 2^{nbits}"
                    );
                }
            }
        }
    }

    #[test]
    fn binary_constant_multiplication() {
        for nbits in WIDTHS {
            let q = 1u128 << nbits;
            let ops = operands(nbits);
            for &x in &ops {
                for &c in &ops {
                    let x_input = DummyVal::to_binary(x, nbits);
                    let output =
                        Dummy::eval(&BinaryConstantMultiplication::new(), (&x_input, c, nbits))
                            .unwrap();
                    assert_eq!(
                        DummyVal::from_binary(&output),
                        (x * c) % q,
                        "{x} * const {c} mod 2^{nbits}"
                    );
                }
            }
        }
    }
}
267}