1use crate::{
2 AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, FancyInput, HasModulus, WireLabel,
3 WireMod2, check_binary,
4 fancy::{BinaryBundle, CrtBundle, Fancy, FancyReveal},
5 garble::binary_and::BinaryWireLabel,
6 hash_wires,
7 util::{RngExt, output_tweak, tweak, tweak2},
8};
9use rand::{CryptoRng, RngCore};
10#[cfg(feature = "serde")]
11use serde::de::DeserializeOwned;
12use std::collections::HashMap;
13use swanky_block::Block;
14use swanky_channel::Channel;
15
16use super::security_warning::warn_proj;
17
/// Garbler side of a garbled-circuit protocol.
///
/// Every wire is represented on the garbler side by its zero-label; the
/// evaluator-side label for value `v` is `zero + v * delta(q)`. Garbled tables
/// are written to a `Channel` as each gate is processed.
pub struct Garbler<RNG, Wire> {
    /// Zero-label of the constant wire created in `Garbler::new`. The evaluator was
    /// sent `zero + delta(2)`, i.e. the label encoding 1; `negate` XORs with this wire.
    zero: Wire,
    /// Map from modulus to the delta (global label offset) used for wires of that modulus.
    deltas: HashMap<u16, Wire>,
    /// Index of the next output; used to derive the per-output hash tweak.
    current_output: usize,
    /// Index of the next garbled gate; used to derive the per-gate hash tweak.
    current_gate: usize,
    /// Source of randomness for wire labels and deltas.
    rng: RNG,
}
28
#[cfg(feature = "serde")]
impl<RNG: CryptoRng + RngCore, Wire: WireLabel + DeserializeOwned> Garbler<RNG, Wire> {
    /// Loads pre-chosen deltas from the JSON file at `filename`.
    ///
    /// The file must deserialize to a map from modulus to wire label. Loaded entries are
    /// merged into this garbler's delta table, overwriting any existing entry for the same
    /// modulus.
    ///
    /// NOTE(review): `Garbler::new` has already written `self.zero + delta(2)` to the
    /// evaluator. If the file contains a modulus-2 entry, that previously sent label no
    /// longer equals `self.zero + delta(2)`, which would break `negate`. Confirm callers
    /// never load a modulus-2 delta (or that this is intended).
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or its contents fail to deserialize.
    pub fn load_deltas(&mut self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        let file = std::fs::File::open(filename)?;
        let loaded: HashMap<u16, Wire> =
            serde_json::from_reader(std::io::BufReader::new(file))?;
        self.deltas.extend(loaded);
        Ok(())
    }
}
40
41impl<RNG: CryptoRng + RngCore, Wire: WireLabel> Garbler<RNG, Wire> {
42 pub fn new(mut rng: RNG, channel: &mut Channel) -> swanky_error::Result<Self> {
44 let zero = Wire::rand(&mut rng, 2);
45 let delta = Wire::rand_delta(&mut rng, 2);
46 let one = zero.plus(&delta);
47 let mut deltas = HashMap::new();
48 deltas.insert(2, delta);
49 channel.write(&one.to_block())?;
52 Ok(Garbler {
53 zero,
54 deltas,
55 current_gate: 0,
56 current_output: 0,
57 rng,
58 })
59 }
60
61 fn current_gate(&mut self) -> usize {
63 let current = self.current_gate;
64 self.current_gate += 1;
65 current
66 }
67
68 pub fn delta(&mut self, q: u16) -> Wire {
71 if let Some(delta) = self.deltas.get(&q) {
72 return delta.clone();
73 }
74 let w = Wire::rand_delta(&mut self.rng, q);
75 self.deltas.insert(q, w.clone());
76 w
77 }
78
79 fn current_output(&mut self) -> usize {
81 let current = self.current_output;
82 self.current_output += 1;
83 current
84 }
85
86 pub fn get_deltas(self) -> HashMap<u16, Wire> {
90 self.deltas
91 }
92
93 pub fn send_wire(&mut self, wire: &Wire, channel: &mut Channel) -> swanky_error::Result<()> {
95 channel.write(&wire.to_block())?;
96 Ok(())
97 }
98
99 pub fn encode_wire(&mut self, val: u16, modulus: u16) -> (Wire, Wire) {
101 let zero = Wire::rand(&mut self.rng, modulus);
102 let delta = self.delta(modulus);
103 let enc = zero.plus(&delta.cmul(val));
104 (zero, enc)
105 }
106
107 pub fn encode_many_wires(&mut self, vals: &[u16], moduli: &[u16]) -> (Vec<Wire>, Vec<Wire>) {
112 assert_eq!(vals.len(), moduli.len());
113
114 let mut gbs = Vec::with_capacity(vals.len());
115 let mut evs = Vec::with_capacity(vals.len());
116 for (x, q) in vals.iter().zip(moduli.iter()) {
117 let (gb, ev) = self.encode_wire(*x, *q);
118 gbs.push(gb);
119 evs.push(ev);
120 }
121 (gbs, evs)
122 }
123
124 pub fn crt_encode_wire(
126 &mut self,
127 val: u128,
128 modulus: u128,
129 ) -> (CrtBundle<Wire>, CrtBundle<Wire>) {
130 let ms = crate::util::factor(modulus);
131 let xs = crate::util::crt(val, &ms);
132 let (gbs, evs) = self.encode_many_wires(&xs, &ms);
133 (CrtBundle::new(gbs), CrtBundle::new(evs))
134 }
135
136 pub fn bin_encode_wire(
138 &mut self,
139 val: u128,
140 nbits: usize,
141 ) -> (BinaryBundle<Wire>, BinaryBundle<Wire>) {
142 let xs = crate::util::u128_to_bits(val, nbits);
143 let ms = vec![2; nbits];
144 let (gbs, evs) = self.encode_many_wires(&xs, &ms);
145 (BinaryBundle::new(gbs), BinaryBundle::new(evs))
146 }
147}
148
149impl<RNG: CryptoRng + RngCore, Wire: WireLabel> FancyInput for Garbler<RNG, Wire> {
150 type Item = Wire;
151
152 fn encode_many(
153 &mut self,
154 values: &[u16],
155 moduli: &[u16],
156 channel: &mut Channel,
157 ) -> swanky_error::Result<Vec<Self::Item>> {
158 let (zero, encoded) = self.encode_many_wires(values, moduli);
159 for wire in encoded {
160 channel.write(&wire.to_block())?;
161 }
162 Ok(zero)
163 }
164
165 fn receive_many(
166 &mut self,
167 _moduli: &[u16],
168 _: &mut Channel,
169 ) -> swanky_error::Result<Vec<Self::Item>> {
170 unimplemented!("Garbler cannot receive values")
171 }
172}
173
174impl<RNG: RngCore + CryptoRng, Wire: WireLabel> FancyReveal for Garbler<RNG, Wire> {
175 fn reveal(&mut self, x: &Wire, channel: &mut Channel) -> swanky_error::Result<u16> {
176 self.output(x, channel)?;
179 let val = channel.read::<u16>()?;
180 Ok(val)
181 }
182}
183
184impl<RNG: RngCore + CryptoRng, W: BinaryWireLabel> FancyBinary for Garbler<RNG, W> {
185 fn and(
186 &mut self,
187 A: &Self::Item,
188 B: &Self::Item,
189 channel: &mut Channel,
190 ) -> swanky_error::Result<Self::Item> {
191 let delta = self.delta(2);
192 let gate_num = self.current_gate();
193 let (gate0, gate1, C) = W::garble_and_gate(gate_num, A, B, &delta);
194 channel.write(&gate0)?;
195 channel.write(&gate1)?;
196 Ok(C)
197 }
198
199 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
200 x.plus(y)
201 }
202
203 fn negate(&mut self, x: &Self::Item) -> Self::Item {
208 let zero = self.zero.clone();
209 self.xor(&zero, x)
210 }
211}
212
213impl<RNG: RngCore + CryptoRng> FancyBinary for Garbler<RNG, AllWire> {
214 fn negate(&mut self, x: &Self::Item) -> Self::Item {
219 check_binary!(x);
220
221 let zero = self.zero.clone();
222 self.xor(&zero, x)
223 }
224
225 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
227 check_binary!(x);
228 check_binary!(y);
229
230 self.add(x, y)
231 }
232
233 fn and(
235 &mut self,
236 x: &Self::Item,
237 y: &Self::Item,
238 channel: &mut Channel,
239 ) -> swanky_error::Result<Self::Item> {
240 if let (AllWire::Mod2(A), AllWire::Mod2(B), AllWire::Mod2(ref delta)) =
241 (x, y, self.delta(2))
242 {
243 let gate_num = self.current_gate();
244 let (gate0, gate1, C) = WireMod2::garble_and_gate(gate_num, A, B, delta);
245 channel.write(&gate0)?;
246 channel.write(&gate1)?;
247 return Ok(AllWire::Mod2(C));
248 }
249 check_binary!(x);
251 check_binary!(y);
252
253 unreachable!()
255 }
256}
257
258impl<RNG: RngCore + CryptoRng, Wire: WireLabel + ArithmeticWire> FancyArithmetic
259 for Garbler<RNG, Wire>
260{
261 fn add(&mut self, x: &Wire, y: &Wire) -> Wire {
262 assert_eq!(x.modulus(), y.modulus());
263 x.plus(y)
264 }
265
266 fn sub(&mut self, x: &Wire, y: &Wire) -> Wire {
267 assert_eq!(x.modulus(), y.modulus());
268 x.minus(y)
269 }
270
271 fn cmul(&mut self, x: &Wire, c: u16) -> Wire {
272 x.cmul(c)
273 }
274
275 fn mul(&mut self, A: &Wire, B: &Wire, channel: &mut Channel) -> swanky_error::Result<Wire> {
276 if A.modulus() < B.modulus() {
277 return self.mul(B, A, channel);
278 }
279
280 let q = A.modulus();
281 let qb = B.modulus();
282 let gate_num = self.current_gate();
283
284 let D = self.delta(q);
285 let Db = self.delta(qb);
286
287 let r;
288 let mut gate = vec![Block::default(); q as usize + qb as usize - 2];
289
290 if q != qb {
292 assert!(
294 qb <= 8,
295 "`B.modulus()` with asymmetric moduli is capped at 8"
296 );
297
298 r = self.rng.gen_u16() % q;
299 let t = tweak2(gate_num as u64, 1);
300
301 let mut minitable = vec![u128::default(); qb as usize];
302 let mut B_ = B.clone();
303 for b in 0..qb {
304 if b > 0 {
305 B_.plus_eq(&Db);
306 }
307 let new_color = ((r + b) % q) as u128;
308 let ct = (u128::from(B_.hash(t)) & 0xFFFF) ^ new_color;
309 minitable[B_.color() as usize] = ct;
310 }
311
312 let mut packed = 0;
313 for i in 0..qb as usize {
314 packed += minitable[i] << (16 * i);
315 }
316 gate.push(Block::from(packed));
317 } else {
318 r = B.color(); }
320
321 let g = tweak2(gate_num as u64, 0);
322
323 let alpha = (q - A.color()) % q; let X1 = A.plus(&D.cmul(alpha));
326
327 let beta = (qb - B.color()) % qb;
329 let Y1 = B.plus(&Db.cmul(beta));
330
331 let [hashX, hashY] = hash_wires([&X1, &Y1], g);
332
333 let X = Wire::hash_to_mod(hashX, q).plus_mov(&D.cmul(alpha * r % q));
334 let Y = Wire::hash_to_mod(hashY, q).plus_mov(&A.cmul((beta + r) % q));
335
336 let mut precomp = Vec::with_capacity(q as usize);
337 let mut X_ = X.clone();
340 precomp.push(X_.to_block());
341 for _ in 1..q {
342 X_.plus_eq(&D);
343 precomp.push(X_.to_block());
344 }
345
346 let mut A_ = A.clone();
350 for a in 0..q {
351 if a > 0 {
352 A_.plus_eq(&D);
353 }
354 if A_.color() != 0 {
357 gate[A_.color() as usize - 1] =
358 A_.hash(g) ^ precomp[((q - (a * r % q)) % q) as usize];
359 }
360 }
361 precomp.clear();
362
363 let mut Y_ = Y.clone();
366 precomp.push(Y_.to_block());
367 for _ in 1..q {
368 Y_.plus_eq(A);
369 precomp.push(Y_.to_block());
370 }
371
372 let mut B_ = B.clone();
374 for b in 0..qb {
375 if b > 0 {
376 B_.plus_eq(&Db);
377 }
378 if B_.color() != 0 {
381 gate[q as usize - 1 + B_.color() as usize - 1] =
382 B_.hash(g) ^ precomp[((q - ((b + r) % q)) % q) as usize];
383 }
384 }
385
386 for block in gate.iter() {
387 channel.write(block)?;
388 }
389 Ok(X.plus_mov(&Y))
390 }
391
392 fn proj(
393 &mut self,
394 A: &Wire,
395 q_out: u16,
396 tt: Option<Vec<u16>>,
397 channel: &mut Channel,
398 ) -> swanky_error::Result<Wire> {
399 warn_proj();
400 assert!(tt.is_some(), "`tt` must not be `None`");
401 let tt = tt.unwrap();
402
403 let q_in = A.modulus();
404 let mut gate = vec![Block::default(); q_in as usize - 1];
405
406 let tao = A.color();
407 let g = tweak(self.current_gate());
408
409 let Din = self.delta(q_in);
410 let Dout = self.delta(q_out);
411
412 let C = A
415 .plus(&Din.cmul((q_in - tao) % q_in))
416 .hashback(g, q_out)
417 .plus_mov(&Dout.cmul((q_out - tt[((q_in - tao) % q_in) as usize]) % q_out));
418
419 let C_precomputed = {
421 let mut C_ = C.clone();
422 (0..q_out)
423 .map(|x| {
424 if x > 0 {
425 C_.plus_eq(&Dout);
426 }
427 C_.to_block()
428 })
429 .collect::<Vec<Block>>()
430 };
431
432 let mut A_ = A.clone();
433 for x in 0..q_in {
434 if x > 0 {
435 A_.plus_eq(&Din); }
437
438 let ix = (tao as usize + x as usize) % q_in as usize;
439 if ix == 0 {
440 continue;
441 }
442
443 let ct = A_.hash(g) ^ C_precomputed[tt[x as usize] as usize];
444 gate[ix - 1] = ct;
445 }
446
447 for block in gate.iter() {
448 channel.write(block)?;
449 }
450 Ok(C)
451 }
452}
453
454impl<RNG: RngCore + CryptoRng, Wire: WireLabel> Fancy for Garbler<RNG, Wire> {
455 type Item = Wire;
456
457 fn constant(&mut self, x: u16, q: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
458 let zero = Wire::rand(&mut self.rng, q);
459 let wire = zero.plus(self.delta(q).cmul_eq(x));
460 self.send_wire(&wire, channel)?;
461 Ok(zero)
462 }
463
464 fn output(&mut self, X: &Wire, channel: &mut Channel) -> swanky_error::Result<Option<u16>> {
465 let q = X.modulus();
466 let i = self.current_output();
467 let D = self.delta(q);
468 for k in 0..q {
469 let block = X.plus(&D.cmul(k)).hash(output_tweak(i, k));
470 channel.write(&block)?;
471 }
472 Ok(None)
473 }
474}