1use crate::{
2 AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, FancyInput, HasModulus, WireLabel,
3 WireMod2, check_binary,
4 fancy::{BinaryBundle, CrtBundle, Fancy, FancyReveal},
5 garble::binary_and::BinaryWireLabel,
6 hash_wires,
7 util::{RngExt, output_tweak, tweak, tweak2},
8};
9use rand::{CryptoRng, RngCore};
10#[cfg(feature = "serde")]
11use serde::de::DeserializeOwned;
12use std::collections::HashMap;
13use swanky_channel::Channel;
14
15use super::security_warning::warn_proj;
16
/// Garbler state: wire labels are generated here and garbled-gate material is
/// written to a `Channel` by the `Fancy*` trait methods as each gate is called.
pub struct Garbler<RNG, Wire> {
    /// Garbler-side mod-2 label created in `new`. The evaluator is sent
    /// `zero + delta(2)`, so this wire carries the value 1 for the evaluator;
    /// `negate` XORs with it.
    zero: Wire,
    /// Delta (global label offset) per modulus. Populated lazily by `delta`,
    /// extendable via `load_deltas`, and returned by `get_deltas`.
    deltas: HashMap<u16, Wire>,
    /// Counter feeding `output_tweak` in `output`; incremented per output call.
    current_output: usize,
    /// Counter used to derive per-gate tweaks in `and`, `mul` and `proj`.
    current_gate: usize,
    /// Randomness source for labels, deltas and the `mul` blinding value.
    rng: RNG,
}
27
#[cfg(feature = "serde")]
impl<RNG: CryptoRng + RngCore, Wire: WireLabel + DeserializeOwned> Garbler<RNG, Wire> {
    /// Reads a JSON map from modulus to delta out of `filename` and merges it
    /// into this garbler's deltas. Entries for moduli that already have a delta
    /// are overwritten.
    ///
    /// NOTE(review): `new` has already sent `zero + delta(2)` to the evaluator,
    /// so loading a different modulus-2 delta afterwards desynchronizes that
    /// label — confirm callers only load deltas for other moduli or before use.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or does not deserialize as
    /// a `HashMap<u16, Wire>`.
    pub fn load_deltas(&mut self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        let file = std::fs::File::open(filename)?;
        let loaded: HashMap<u16, Wire> =
            serde_json::from_reader(std::io::BufReader::new(file))?;
        for (modulus, delta) in loaded {
            self.deltas.insert(modulus, delta);
        }
        Ok(())
    }
}
39
40impl<RNG: CryptoRng + RngCore, Wire: WireLabel> Garbler<RNG, Wire> {
41 pub fn new(mut rng: RNG, channel: &mut Channel) -> swanky_error::Result<Self> {
43 let zero = Wire::rand(&mut rng, 2);
44 let delta = Wire::rand_delta(&mut rng, 2);
45 let one = zero.clone() + delta.clone();
46 let mut deltas = HashMap::new();
47 deltas.insert(2, delta);
48 channel.write(&one.to_repr())?;
51 Ok(Garbler {
52 zero,
53 deltas,
54 current_gate: 0,
55 current_output: 0,
56 rng,
57 })
58 }
59
60 fn current_gate(&mut self) -> usize {
62 let current = self.current_gate;
63 self.current_gate += 1;
64 current
65 }
66
67 pub fn delta(&mut self, q: u16) -> Wire {
70 if let Some(delta) = self.deltas.get(&q) {
71 return delta.clone();
72 }
73 let w = Wire::rand_delta(&mut self.rng, q);
74 self.deltas.insert(q, w.clone());
75 w
76 }
77
78 fn current_output(&mut self) -> usize {
80 let current = self.current_output;
81 self.current_output += 1;
82 current
83 }
84
85 pub fn get_deltas(self) -> HashMap<u16, Wire> {
89 self.deltas
90 }
91
92 pub fn send_wire(&mut self, wire: &Wire, channel: &mut Channel) -> swanky_error::Result<()> {
94 channel.write(&wire.to_repr())?;
95 Ok(())
96 }
97
98 pub fn encode_wire(&mut self, val: u16, modulus: u16) -> (Wire, Wire) {
100 let zero = Wire::rand(&mut self.rng, modulus);
101 let delta = self.delta(modulus);
102 let enc = zero.clone() + delta * val;
103 (zero, enc)
104 }
105
106 pub fn encode_many_wires(&mut self, vals: &[u16], moduli: &[u16]) -> (Vec<Wire>, Vec<Wire>) {
111 assert_eq!(vals.len(), moduli.len());
112
113 let mut gbs = Vec::with_capacity(vals.len());
114 let mut evs = Vec::with_capacity(vals.len());
115 for (x, q) in vals.iter().zip(moduli.iter()) {
116 let (gb, ev) = self.encode_wire(*x, *q);
117 gbs.push(gb);
118 evs.push(ev);
119 }
120 (gbs, evs)
121 }
122
123 pub fn crt_encode_wire(
125 &mut self,
126 val: u128,
127 modulus: u128,
128 ) -> (CrtBundle<Wire>, CrtBundle<Wire>) {
129 let ms = crate::util::factor(modulus);
130 let xs = crate::util::crt(val, &ms);
131 let (gbs, evs) = self.encode_many_wires(&xs, &ms);
132 (CrtBundle::new(gbs), CrtBundle::new(evs))
133 }
134
135 pub fn bin_encode_wire(
137 &mut self,
138 val: u128,
139 nbits: usize,
140 ) -> (BinaryBundle<Wire>, BinaryBundle<Wire>) {
141 let xs = crate::util::u128_to_bits(val, nbits);
142 let ms = vec![2; nbits];
143 let (gbs, evs) = self.encode_many_wires(&xs, &ms);
144 (BinaryBundle::new(gbs), BinaryBundle::new(evs))
145 }
146}
147
148impl<RNG: CryptoRng + RngCore, Wire: WireLabel> FancyInput for Garbler<RNG, Wire> {
149 type Item = Wire;
150
151 fn encode_many(
152 &mut self,
153 values: &[u16],
154 moduli: &[u16],
155 channel: &mut Channel,
156 ) -> swanky_error::Result<Vec<Self::Item>> {
157 let (zero, encoded) = self.encode_many_wires(values, moduli);
158 for wire in encoded {
159 channel.write(&wire.to_repr())?;
160 }
161 Ok(zero)
162 }
163
164 fn receive_many(
165 &mut self,
166 _moduli: &[u16],
167 _: &mut Channel,
168 ) -> swanky_error::Result<Vec<Self::Item>> {
169 unimplemented!("Garbler cannot receive values")
170 }
171}
172
173impl<RNG: RngCore + CryptoRng, Wire: WireLabel> FancyReveal for Garbler<RNG, Wire> {
174 fn reveal(&mut self, x: &Wire, channel: &mut Channel) -> swanky_error::Result<u16> {
175 self.output(x, channel)?;
178 let val = channel.read::<u16>()?;
179 Ok(val)
180 }
181}
182
183impl<RNG: RngCore + CryptoRng, W: BinaryWireLabel> FancyBinary for Garbler<RNG, W> {
184 fn and(
185 &mut self,
186 A: &Self::Item,
187 B: &Self::Item,
188 channel: &mut Channel,
189 ) -> swanky_error::Result<Self::Item> {
190 let delta = self.delta(2);
191 let gate_num = self.current_gate();
192 let (gate0, gate1, C) = W::garble_and_gate(gate_num, A, B, &delta);
193 channel.write(&gate0)?;
194 channel.write(&gate1)?;
195 Ok(C)
196 }
197
198 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
199 *x + *y
200 }
201
202 fn negate(&mut self, x: &Self::Item) -> Self::Item {
207 let zero = self.zero;
208 self.xor(&zero, x)
209 }
210}
211
212impl<RNG: RngCore + CryptoRng> FancyBinary for Garbler<RNG, AllWire> {
213 fn negate(&mut self, x: &Self::Item) -> Self::Item {
218 check_binary!(x);
219
220 let zero = self.zero.clone();
221 self.xor(&zero, x)
222 }
223
224 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
226 check_binary!(x);
227 check_binary!(y);
228
229 self.add(x, y)
230 }
231
232 fn and(
234 &mut self,
235 x: &Self::Item,
236 y: &Self::Item,
237 channel: &mut Channel,
238 ) -> swanky_error::Result<Self::Item> {
239 if let (AllWire::Mod2(A), AllWire::Mod2(B), AllWire::Mod2(ref delta)) =
240 (x, y, self.delta(2))
241 {
242 let gate_num = self.current_gate();
243 let (gate0, gate1, C) = WireMod2::garble_and_gate(gate_num, A, B, delta);
244 channel.write(&gate0)?;
245 channel.write(&gate1)?;
246 return Ok(AllWire::Mod2(C));
247 }
248 check_binary!(x);
250 check_binary!(y);
251
252 unreachable!()
254 }
255}
256
257impl<RNG: RngCore + CryptoRng, Wire: WireLabel + ArithmeticWire> FancyArithmetic
258 for Garbler<RNG, Wire>
259{
260 fn add(&mut self, x: &Wire, y: &Wire) -> Wire {
261 assert_eq!(x.modulus(), y.modulus());
262 x.clone() + y.clone()
263 }
264
265 fn sub(&mut self, x: &Wire, y: &Wire) -> Wire {
266 assert_eq!(x.modulus(), y.modulus());
267 x.clone() - y.clone()
268 }
269
270 fn cmul(&mut self, x: &Wire, c: u16) -> Wire {
271 x.clone() * c
272 }
273
274 fn mul(&mut self, A: &Wire, B: &Wire, channel: &mut Channel) -> swanky_error::Result<Wire> {
275 if A.modulus() < B.modulus() {
276 return self.mul(B, A, channel);
277 }
278
279 let q = A.modulus();
280 let qb = B.modulus();
281 let gate_num = self.current_gate();
282
283 let D = self.delta(q);
284 let Db = self.delta(qb);
285
286 let r;
287 let mut gate = vec![Default::default(); q as usize + qb as usize - 2];
288
289 if q != qb {
291 assert!(
293 qb <= 8,
294 "`B.modulus()` with asymmetric moduli is capped at 8"
295 );
296
297 r = self.rng.gen_u16() % q;
298 let t = tweak2(gate_num as u64, 1);
299
300 let mut minitable = vec![u128::default(); qb as usize];
301 let mut B_ = B.clone();
302 for b in 0..qb {
303 if b > 0 {
304 B_ += Db.clone();
305 }
306 let new_color = ((r + b) % q) as u128;
307 let ct = (u128::from(B_.hash(t)) & 0xFFFF) ^ new_color;
308 minitable[B_.color() as usize] = ct;
309 }
310
311 let mut packed = 0;
312 for (i, item) in minitable.iter().enumerate().take(qb as usize) {
313 packed += item << (16 * i);
314 }
315 gate.push(packed.into());
316 } else {
317 r = B.color(); }
319
320 let g = tweak2(gate_num as u64, 0);
321
322 let alpha = (q - A.color()) % q; let X1 = A.clone() + D.clone() * alpha;
325
326 let beta = (qb - B.color()) % qb;
328 let Y1 = B.clone() + Db.clone() * beta;
329
330 let [hashX, hashY] = hash_wires([&X1, &Y1], g);
331
332 let X = Wire::hash_to_mod(hashX, q) + D.clone() * (alpha * r % q);
333 let Y = Wire::hash_to_mod(hashY, q) + A.clone() * ((beta + r) % q);
334
335 let mut precomp = Vec::with_capacity(q as usize);
336 let mut X_ = X.clone();
339 precomp.push(X_.to_repr());
340 for _ in 1..q {
341 X_ += D.clone();
342 precomp.push(X_.to_repr());
343 }
344
345 let mut A_ = A.clone();
349 for a in 0..q {
350 if a > 0 {
351 A_ += D.clone();
352 }
353 if A_.color() != 0 {
356 gate[A_.color() as usize - 1] =
357 A_.hash(g) ^ precomp[((q - (a * r % q)) % q) as usize];
358 }
359 }
360 precomp.clear();
361
362 let mut Y_ = Y.clone();
365 precomp.push(Y_.to_repr());
366 for _ in 1..q {
367 Y_ += A.clone();
368 precomp.push(Y_.to_repr());
369 }
370
371 let mut B_ = B.clone();
373 for b in 0..qb {
374 if b > 0 {
375 B_ += Db.clone();
376 }
377 if B_.color() != 0 {
380 gate[q as usize - 1 + B_.color() as usize - 1] =
381 B_.hash(g) ^ precomp[((q - ((b + r) % q)) % q) as usize];
382 }
383 }
384
385 for block in gate.iter() {
386 channel.write(block)?;
387 }
388 Ok(X + Y)
389 }
390
391 fn proj(
392 &mut self,
393 A: &Wire,
394 q_out: u16,
395 tt: Option<Vec<u16>>,
396 channel: &mut Channel,
397 ) -> swanky_error::Result<Wire> {
398 warn_proj();
399 assert!(tt.is_some(), "`tt` must not be `None`");
400 let tt = tt.unwrap();
401
402 let q_in = A.modulus();
403 let mut gate = vec![Default::default(); q_in as usize - 1];
404
405 let tao = A.color();
406 let g = tweak(self.current_gate());
407
408 let Din = self.delta(q_in);
409 let Dout = self.delta(q_out);
410
411 let C = (A.clone() + Din.clone() * ((q_in - tao) % q_in)).hashback(g, q_out)
414 + Dout.clone() * ((q_out - tt[((q_in - tao) % q_in) as usize]) % q_out);
415
416 let C_precomputed = {
418 let mut C_ = C.clone();
419 (0..q_out)
420 .map(|x| {
421 if x > 0 {
422 C_ += Dout.clone();
423 }
424 C_.to_repr()
425 })
426 .collect::<Vec<_>>()
427 };
428
429 let mut A_ = A.clone();
430 for x in 0..q_in {
431 if x > 0 {
432 A_ += Din.clone(); }
434
435 let ix = (tao as usize + x as usize) % q_in as usize;
436 if ix == 0 {
437 continue;
438 }
439
440 let ct = A_.hash(g) ^ C_precomputed[tt[x as usize] as usize];
441 gate[ix - 1] = ct;
442 }
443
444 for block in gate.iter() {
445 channel.write(block)?;
446 }
447 Ok(C)
448 }
449}
450
451impl<RNG: RngCore + CryptoRng, Wire: WireLabel> Fancy for Garbler<RNG, Wire> {
452 type Item = Wire;
453
454 fn constant(&mut self, x: u16, q: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
455 let zero = Wire::rand(&mut self.rng, q);
456 let wire = zero.clone() + self.delta(q) * x;
457 self.send_wire(&wire, channel)?;
458 Ok(zero)
459 }
460
461 fn output(&mut self, X: &Wire, channel: &mut Channel) -> swanky_error::Result<Option<u16>> {
462 let q = X.modulus();
463 let i = self.current_output();
464 let D = self.delta(q);
465 for k in 0..q {
466 let block = (X.clone() + D.clone() * k).hash(output_tweak(i, k));
467 channel.write(&block)?;
468 }
469 Ok(None)
470 }
471}