fancy_garbling/circuits/binary/
binary_multiplication.rs1use crate::{
2 BinaryBundle, FancyBinary,
3 circuit::{Circuit, CircuitInputMapper},
4 circuits::binary::{
5 BinaryAddition, BinaryAdditionNoCarry, BinaryConstant, BinaryLeftShift,
6 BinaryLeftShiftExtend,
7 },
8 util::u128_to_bits,
9};
10use core::marker::PhantomData;
11use swanky_channel::Channel;
12use swanky_error::Result;
13
/// Circuit that multiplies two equal-width binary bundles and returns the full product:
/// no high-order bits are discarded, so the output is wider than the operands.
#[derive(Default)]
pub struct BinaryMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> BinaryMultiplication<'a> {
    /// Creates the multiplication circuit. It carries no state beyond a lifetime marker.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
24
25impl<'a, F: FancyBinary> Circuit<F> for BinaryMultiplication<'a>
26where
27 F::Item: 'a,
28{
29 type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
30 type Output = BinaryBundle<F::Item>;
31
32 fn execute(
33 &self,
34 backend: &mut F,
35 inputs: Self::Input,
36 channel: &mut Channel,
37 ) -> Result<Self::Output> {
38 let (xs, ys) = inputs;
39 assert_eq!(xs.moduli(), ys.moduli());
40
41 let xwires = xs.wires();
42 let ywires = ys.wires();
43
44 let zero = backend.constant(0, 2, channel)?;
45
46 let mut sum = xwires
47 .iter()
48 .map(|x| backend.and(x, &ywires[0], channel))
49 .collect::<Result<_>>()
50 .map(BinaryBundle::new)?;
51
52 sum.pad(&zero, 1);
53
54 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
55 let mul = xwires
56 .iter()
57 .map(|x| backend.and(x, ywire, channel))
58 .collect::<Result<_>>()
59 .map(BinaryBundle::new)?;
60 let shifted = BinaryLeftShiftExtend::new().execute(backend, (&mul, i), channel)?;
61 let res = BinaryAddition::new().execute(backend, (&sum, &shifted), channel)?;
62 sum = res.0;
63 sum.push(res.1);
64 }
65
66 Ok(sum)
67 }
68}
69
/// Circuit that multiplies two equal-width binary bundles and keeps only the low-order
/// half of the product, i.e. the product modulo `2^n` for `n`-bit operands.
#[derive(Default)]
pub struct BinaryMultiplicationLowerHalf<'a>(PhantomData<&'a ()>);

impl<'a> BinaryMultiplicationLowerHalf<'a> {
    /// Creates the truncating multiplication circuit. It carries no state beyond a
    /// lifetime marker.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
81
82impl<'a, F: FancyBinary> Circuit<F> for BinaryMultiplicationLowerHalf<'a>
83where
84 F::Item: 'a,
85{
86 type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
87 type Output = BinaryBundle<F::Item>;
88
89 fn execute(
90 &self,
91 backend: &mut F,
92 inputs: Self::Input,
93 channel: &mut Channel,
94 ) -> Result<Self::Output> {
95 let (xs, ys) = inputs;
96 assert_eq!(xs.moduli(), ys.moduli());
97
98 let xwires = xs.wires();
99 let ywires = ys.wires();
100
101 let mut sum = xwires
102 .iter()
103 .map(|x| backend.and(x, &ywires[0], channel))
104 .collect::<Result<_>>()
105 .map(BinaryBundle::new)?;
106
107 for (i, ywire) in ywires.iter().enumerate().take(xwires.len()).skip(1) {
108 let mul = xwires
109 .iter()
110 .map(|x| backend.and(x, ywire, channel))
111 .collect::<Result<_>>()
112 .map(BinaryBundle::new)?;
113 let shifted = BinaryLeftShift::new().execute(backend, (&mul, i), channel)?;
114 sum = BinaryAdditionNoCarry::new().execute(backend, (&sum, &shifted), channel)?;
115 }
116 Ok(sum)
117 }
118}
119
/// Circuit that multiplies a binary bundle by a constant `u128` supplied in the clear,
/// keeping only the requested number of low-order bits.
#[derive(Default)]
pub struct BinaryConstantMultiplication<'a>(PhantomData<&'a ()>);

impl<'a> BinaryConstantMultiplication<'a> {
    /// Creates the constant-multiplication circuit. It carries no state beyond a lifetime
    /// marker.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
131
132impl<'a, F: FancyBinary> Circuit<F> for BinaryConstantMultiplication<'a>
133where
134 F::Item: 'a,
135{
136 type Input = (&'a BinaryBundle<F::Item>, u128, usize);
137 type Output = BinaryBundle<F::Item>;
138
139 fn execute(
140 &self,
141 backend: &mut F,
142 inputs: Self::Input,
143 channel: &mut Channel,
144 ) -> Result<Self::Output> {
145 let (x, c, nbits) = inputs;
146 let zero = BinaryConstant::new(0, nbits).execute(backend, (), channel)?;
147 u128_to_bits(c, nbits)
148 .into_iter()
149 .enumerate()
150 .filter_map(|(i, b)| if b > 0 { Some(i) } else { None })
151 .try_fold(zero, |z, shift_amt| {
152 let s = BinaryLeftShift::new().execute(backend, (x, shift_amt), channel)?;
153 BinaryAdditionNoCarry::new().execute(backend, (&z, &s), channel)
154 })
155 }
156}
157
/// Wrapper around `BinaryMultiplication` whose public field records the operand width in
/// bits. The width is used to split a flat list of input wires into the two operands.
pub struct TestBinaryMultiplication<'a>(pub usize, PhantomData<&'a ()>);

impl<'a> TestBinaryMultiplication<'a> {
    /// Creates a wrapper for multiplying two `nbits`-bit operands.
    pub fn new(nbits: usize) -> Self {
        Self(nbits, PhantomData)
    }
}
167
168impl<'a, F: FancyBinary> Circuit<F> for TestBinaryMultiplication<'a>
169where
170 F::Item: 'a,
171{
172 type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
173 type Output = BinaryBundle<F::Item>;
174
175 fn execute(
176 &self,
177 backend: &mut F,
178 inputs: Self::Input,
179 channel: &mut Channel,
180 ) -> Result<Self::Output> {
181 BinaryMultiplication::new().execute(backend, inputs, channel)
182 }
183}
184
185impl<'a, F: FancyBinary> CircuitInputMapper<F> for TestBinaryMultiplication<'a>
186where
187 F::Item: 'a,
188{
189 fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
190 assert_eq!(inputs.len(), self.0 * 2);
191 let (x, y) = inputs.split_at(self.0);
192 let x = BinaryBundle::new(x.to_vec());
193 let y = BinaryBundle::new(y.to_vec());
194 let x_ref: &'a BinaryBundle<F::Item> = Box::leak(Box::new(x));
196 let y_ref: &'a BinaryBundle<F::Item> = Box::leak(Box::new(y));
197 (x_ref, y_ref)
198 }
199
200 fn ninputs(&self) -> usize {
201 self.0 * 2
202 }
203
204 fn modulus(&self, _: usize) -> u16 {
205 2
206 }
207}
208
#[cfg(test)]
mod test {
    use crate::{
        circuits::binary::{
            BinaryConstantMultiplication, BinaryMultiplication, BinaryMultiplicationLowerHalf,
        },
        dummy::{Dummy, DummyVal},
    };
    use rand::{Rng, thread_rng};

    /// Operand width, in bits, used by every test in this module.
    const NBITS: usize = 64;

    /// `2^NBITS`: operands are reduced modulo this, and so are truncated products.
    const Q: u128 = 1 << NBITS;

    /// Number of random operand pairs checked per test, in addition to the boundary pairs.
    const RANDOM_TRIALS: usize = 16;

    /// Returns deterministic boundary operand pairs followed by random pairs, all below `Q`.
    ///
    /// Random 64-bit operands essentially never hit zero, one or all-ones, yet all-ones
    /// operands drive the longest carry chains through the adders. Those cases are therefore
    /// listed explicitly. `(Q - 1)^2 < 2^128`, so every product fits in a `u128`.
    fn operand_pairs() -> Vec<(u128, u128)> {
        let max = Q - 1;
        let mut pairs = vec![
            (0, 0),
            (1, 1),
            (0, max),
            (max, 0),
            (1, max),
            (max, 1),
            (max, max),
        ];
        let mut rng = thread_rng();
        pairs.extend(
            (0..RANDOM_TRIALS).map(|_| (rng.r#gen::<u128>() % Q, rng.r#gen::<u128>() % Q)),
        );
        pairs
    }

    #[test]
    fn binary_multiplication() {
        for (x, y) in operand_pairs() {
            let x_input = DummyVal::to_binary(x, NBITS);
            let y_input = DummyVal::to_binary(y, NBITS);
            let output = Dummy::eval(&BinaryMultiplication::new(), (&x_input, &y_input)).unwrap();
            // Full product: nothing is truncated.
            assert_eq!(DummyVal::from_binary(&output), x * y, "x = {x}, y = {y}");
        }
    }

    #[test]
    fn binary_multiplication_lower_half() {
        for (x, y) in operand_pairs() {
            let x_input = DummyVal::to_binary(x, NBITS);
            let y_input = DummyVal::to_binary(y, NBITS);
            let output =
                Dummy::eval(&BinaryMultiplicationLowerHalf::new(), (&x_input, &y_input)).unwrap();
            assert_eq!(DummyVal::from_binary(&output), (x * y) % Q, "x = {x}, y = {y}");
        }
    }

    #[test]
    fn binary_constant_multiplication() {
        for (x, c) in operand_pairs() {
            let x_input = DummyVal::to_binary(x, NBITS);
            let output =
                Dummy::eval(&BinaryConstantMultiplication::new(), (&x_input, c, NBITS)).unwrap();
            assert_eq!(DummyVal::from_binary(&output), (x * c) % Q, "x = {x}, c = {c}");
        }
    }
}