1use crate::{
2 AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, FancyProj, HasModulus, WireLabel,
3 WireMod2, check_binary,
4 fancy::{BinaryBundle, Fancy},
5 garble::binary_and::BinaryWireLabel,
6 hash_wires,
7 util::{RngExt, output_tweak, tweak, tweak2},
8};
9use rand::{CryptoRng, RngCore};
10#[cfg(feature = "serde")]
11use serde::de::DeserializeOwned;
12use std::collections::HashMap;
13use swanky_channel::Channel;
14
15use super::security_warning::warn_proj;
16
/// State of a garbler that streams garbled gates over a channel.
///
/// The struct itself carries no trait bounds; the `impl` blocks add them.
pub struct Garbler<RNG, Wire> {
    /// Next gate index handed out by `current_gate`; gate indices feed the
    /// per-gate tweaks (`tweak` / `tweak2`).
    current_gate: usize,
    /// Next output index handed out by `current_output`; fed to `output_tweak`.
    current_output: usize,
    /// Per-modulus global offsets, keyed by modulus. The mod-2 entry is chosen
    /// in `new`; other entries are sampled lazily by `delta`.
    deltas: HashMap<u16, Wire>,
    /// Random mod-2 label sampled in `new`. `new` sends `zero + delta(2)` to the
    /// channel, and `negate` XORs its input with this label.
    zero: Wire,
    /// Randomness source for labels, deltas and gate randomness.
    rng: RNG,
}
27
#[cfg(feature = "serde")]
impl<RNG: CryptoRng + RngCore, Wire: WireLabel + DeserializeOwned> Garbler<RNG, Wire> {
    /// Loads a JSON-encoded `modulus -> delta` map from `filename` and merges it
    /// into this garbler's delta table.
    ///
    /// Entries from the file replace any existing entry for the same modulus.
    /// This includes the mod-2 delta chosen in `new`, whose `zero + delta`
    /// label has already been written to the channel by the time this can be
    /// called. Callers replacing the mod-2 delta should keep that in mind.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or its contents cannot be
    /// deserialized as `HashMap<u16, Wire>`.
    pub fn load_deltas(&mut self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        let file = std::fs::File::open(filename)?;
        let loaded = serde_json::from_reader::<_, HashMap<u16, Wire>>(std::io::BufReader::new(file))?;
        for (modulus, delta) in loaded {
            self.deltas.insert(modulus, delta);
        }
        Ok(())
    }
}
39
40impl<RNG: CryptoRng + RngCore, Wire: WireLabel> Garbler<RNG, Wire> {
41 pub fn new(mut rng: RNG, channel: &mut Channel) -> swanky_error::Result<Self> {
43 let zero = Wire::rand(&mut rng, 2);
44 let delta = Wire::rand_delta(&mut rng, 2);
45 let one = zero.clone() + delta.clone();
46 let mut deltas = HashMap::new();
47 deltas.insert(2, delta);
48 channel.write(&one.to_repr())?;
51 Ok(Garbler {
52 zero,
53 deltas,
54 current_gate: 0,
55 current_output: 0,
56 rng,
57 })
58 }
59
60 fn current_gate(&mut self) -> usize {
62 let current = self.current_gate;
63 self.current_gate += 1;
64 current
65 }
66
67 pub fn delta(&mut self, q: u16) -> Wire {
70 if let Some(delta) = self.deltas.get(&q) {
71 return delta.clone();
72 }
73 let w = Wire::rand_delta(&mut self.rng, q);
74 self.deltas.insert(q, w.clone());
75 w
76 }
77
78 fn current_output(&mut self) -> usize {
80 let current = self.current_output;
81 self.current_output += 1;
82 current
83 }
84
85 pub fn get_deltas(self) -> HashMap<u16, Wire> {
89 self.deltas
90 }
91
92 pub fn encode_zero(&mut self, modulus: u16) -> Wire {
94 Wire::rand(&mut self.rng, modulus)
95 }
96
97 pub fn bin_encode_zero(&mut self, nbits: usize) -> BinaryBundle<Wire> {
99 let zeros = (0..nbits).map(|_| self.encode_zero(2)).collect::<Vec<_>>();
100 BinaryBundle::new(zeros)
101 }
102}
103
104impl<RNG: RngCore + CryptoRng, W: BinaryWireLabel> FancyBinary for Garbler<RNG, W> {
105 fn and(
106 &mut self,
107 A: &Self::Item,
108 B: &Self::Item,
109 channel: &mut Channel,
110 ) -> swanky_error::Result<Self::Item> {
111 let delta = self.delta(2);
112 let gate_num = self.current_gate();
113 let (gate0, gate1, C) = W::garble_and_gate(gate_num, A, B, &delta);
114 channel.write(&gate0)?;
115 channel.write(&gate1)?;
116 Ok(C)
117 }
118
119 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
120 *x + *y
121 }
122
123 fn negate(&mut self, x: &Self::Item) -> Self::Item {
128 let zero = self.zero;
129 self.xor(&zero, x)
130 }
131}
132
133impl<RNG: RngCore + CryptoRng> FancyBinary for Garbler<RNG, AllWire> {
134 fn negate(&mut self, x: &Self::Item) -> Self::Item {
139 check_binary!(x);
140
141 let zero = self.zero.clone();
142 self.xor(&zero, x)
143 }
144
145 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
147 check_binary!(x);
148 check_binary!(y);
149
150 self.add(x, y)
151 }
152
153 fn and(
155 &mut self,
156 x: &Self::Item,
157 y: &Self::Item,
158 channel: &mut Channel,
159 ) -> swanky_error::Result<Self::Item> {
160 if let (AllWire::Mod2(A), AllWire::Mod2(B), AllWire::Mod2(ref delta)) =
161 (x, y, self.delta(2))
162 {
163 let gate_num = self.current_gate();
164 let (gate0, gate1, C) = WireMod2::garble_and_gate(gate_num, A, B, delta);
165 channel.write(&gate0)?;
166 channel.write(&gate1)?;
167 return Ok(AllWire::Mod2(C));
168 }
169 check_binary!(x);
171 check_binary!(y);
172
173 unreachable!()
175 }
176}
177
178impl<RNG: RngCore + CryptoRng, Wire: WireLabel + ArithmeticWire> FancyArithmetic
179 for Garbler<RNG, Wire>
180{
181 fn add(&mut self, x: &Wire, y: &Wire) -> Wire {
182 assert_eq!(x.modulus(), y.modulus());
183 x.clone() + y.clone()
184 }
185
186 fn sub(&mut self, x: &Wire, y: &Wire) -> Wire {
187 assert_eq!(x.modulus(), y.modulus());
188 x.clone() - y.clone()
189 }
190
191 fn cmul(&mut self, x: &Wire, c: u16) -> Wire {
192 x.clone() * c
193 }
194
195 fn mul(&mut self, A: &Wire, B: &Wire, channel: &mut Channel) -> swanky_error::Result<Wire> {
196 if A.modulus() < B.modulus() {
197 return self.mul(B, A, channel);
198 }
199
200 let q = A.modulus();
201 let qb = B.modulus();
202 let gate_num = self.current_gate();
203
204 let D = self.delta(q);
205 let Db = self.delta(qb);
206
207 let r;
208 let mut gate = vec![Default::default(); q as usize + qb as usize - 2];
209
210 if q != qb {
212 assert!(
214 qb <= 8,
215 "`B.modulus()` with asymmetric moduli is capped at 8"
216 );
217
218 r = self.rng.gen_u16() % q;
219 let t = tweak2(gate_num as u64, 1);
220
221 let mut minitable = vec![u128::default(); qb as usize];
222 let mut B_ = B.clone();
223 for b in 0..qb {
224 if b > 0 {
225 B_ += Db.clone();
226 }
227 let new_color = ((r + b) % q) as u128;
228 let ct = (u128::from(B_.hash(t)) & 0xFFFF) ^ new_color;
229 minitable[B_.color() as usize] = ct;
230 }
231
232 let mut packed = 0;
233 for (i, item) in minitable.iter().enumerate().take(qb as usize) {
234 packed += item << (16 * i);
235 }
236 gate.push(packed.into());
237 } else {
238 r = B.color(); }
240
241 let g = tweak2(gate_num as u64, 0);
242
243 let alpha = (q - A.color()) % q; let X1 = A.clone() + D.clone() * alpha;
246
247 let beta = (qb - B.color()) % qb;
249 let Y1 = B.clone() + Db.clone() * beta;
250
251 let [hashX, hashY] = hash_wires([&X1, &Y1], g);
252
253 let X = Wire::hash_to_mod(hashX, q) + D.clone() * (alpha * r % q);
254 let Y = Wire::hash_to_mod(hashY, q) + A.clone() * ((beta + r) % q);
255
256 let mut precomp = Vec::with_capacity(q as usize);
257 let mut X_ = X.clone();
260 precomp.push(X_.to_repr());
261 for _ in 1..q {
262 X_ += D.clone();
263 precomp.push(X_.to_repr());
264 }
265
266 let mut A_ = A.clone();
270 for a in 0..q {
271 if a > 0 {
272 A_ += D.clone();
273 }
274 if A_.color() != 0 {
277 gate[A_.color() as usize - 1] =
278 A_.hash(g) ^ precomp[((q - (a * r % q)) % q) as usize];
279 }
280 }
281 precomp.clear();
282
283 let mut Y_ = Y.clone();
286 precomp.push(Y_.to_repr());
287 for _ in 1..q {
288 Y_ += A.clone();
289 precomp.push(Y_.to_repr());
290 }
291
292 let mut B_ = B.clone();
294 for b in 0..qb {
295 if b > 0 {
296 B_ += Db.clone();
297 }
298 if B_.color() != 0 {
301 gate[q as usize - 1 + B_.color() as usize - 1] =
302 B_.hash(g) ^ precomp[((q - ((b + r) % q)) % q) as usize];
303 }
304 }
305
306 for block in gate.iter() {
307 channel.write(block)?;
308 }
309 Ok(X + Y)
310 }
311}
312
313impl<RNG: RngCore + CryptoRng, Wire: WireLabel + ArithmeticWire> FancyProj for Garbler<RNG, Wire> {
314 fn proj(
315 &mut self,
316 A: &Wire,
317 q_out: u16,
318 tt: Option<Vec<u16>>,
319 channel: &mut Channel,
320 ) -> swanky_error::Result<Wire> {
321 warn_proj();
322 assert!(tt.is_some(), "`tt` must not be `None`");
323 let tt = tt.unwrap();
324
325 let q_in = A.modulus();
326 let mut gate = vec![Default::default(); q_in as usize - 1];
327
328 let tao = A.color();
329 let g = tweak(self.current_gate());
330
331 let Din = self.delta(q_in);
332 let Dout = self.delta(q_out);
333
334 let C = (A.clone() + Din.clone() * ((q_in - tao) % q_in)).hashback(g, q_out)
337 + Dout.clone() * ((q_out - tt[((q_in - tao) % q_in) as usize]) % q_out);
338
339 let C_precomputed = {
341 let mut C_ = C.clone();
342 (0..q_out)
343 .map(|x| {
344 if x > 0 {
345 C_ += Dout.clone();
346 }
347 C_.to_repr()
348 })
349 .collect::<Vec<_>>()
350 };
351
352 let mut A_ = A.clone();
353 for x in 0..q_in {
354 if x > 0 {
355 A_ += Din.clone(); }
357
358 let ix = (tao as usize + x as usize) % q_in as usize;
359 if ix == 0 {
360 continue;
361 }
362
363 let ct = A_.hash(g) ^ C_precomputed[tt[x as usize] as usize];
364 gate[ix - 1] = ct;
365 }
366
367 for block in gate.iter() {
368 channel.write(block)?;
369 }
370 Ok(C)
371 }
372}
373
374impl<RNG: RngCore + CryptoRng, Wire: WireLabel> Fancy for Garbler<RNG, Wire> {
375 type Item = Wire;
376
377 fn encode_many(
378 &mut self,
379 values: &[u16],
380 moduli: &[u16],
381 channel: &mut Channel,
382 ) -> swanky_error::Result<Vec<Self::Item>> {
383 assert_eq!(values.len(), moduli.len());
384
385 let mut zeros = Vec::with_capacity(values.len());
386 for (x, q) in values.iter().zip(moduli.iter()) {
387 let delta = self.delta(*q);
388 let zero = self.encode_zero(*q);
389 let encoded = zero.clone() + delta * *x;
390 channel.write(&encoded.to_repr())?;
391 zeros.push(zero);
392 }
393 Ok(zeros)
394 }
395
396 fn receive_many(
397 &mut self,
398 _moduli: &[u16],
399 _: &mut Channel,
400 ) -> swanky_error::Result<Vec<Self::Item>> {
401 unimplemented!("Garbler cannot receive values")
402 }
403
404 fn constant(&mut self, x: u16, q: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
405 let (zero, wire) = Wire::constant(x, q, &self.delta(q), &mut self.rng);
406 channel.write(&wire.to_repr())?;
407 Ok(zero)
408 }
409
410 fn output(&mut self, X: &Wire, channel: &mut Channel) -> swanky_error::Result<Option<u16>> {
411 let q = X.modulus();
412 let i = self.current_output();
413 let D = self.delta(q);
414 for k in 0..q {
415 let block = (X.clone() + D.clone() * k).hash(output_tweak(i, k));
416 channel.write(&block)?;
417 }
418 Ok(None)
419 }
420}