1use crate::{
5 Fancy, WireLabel,
6 circuit::EvaluableCircuit,
7 garble::{Evaluator, Garbler},
8 util::output_tweak,
9};
10use itertools::Itertools;
11use rand::{CryptoRng, RngCore};
12use std::collections::HashMap;
13use swanky_channel::Channel;
14use swanky_error::ErrorKind;
15use vectoreyes::U8x16;
16
/// A garbled circuit, stored as the sequence of 16-byte blocks the garbler
/// wrote while garbling.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GarbledCircuit {
    /// Garbled blocks, in the order they were written; evaluation reads them
    /// back in the same order.
    blocks: Vec<U8x16>,
}
26
27impl GarbledCircuit {
28 pub fn new(blocks: Vec<U8x16>) -> Self {
31 GarbledCircuit { blocks }
32 }
33
34 pub fn size(&self) -> usize {
36 self.blocks.len()
37 }
38
39 pub fn garble<
47 Wire: WireLabel,
48 Circuit: EvaluableCircuit<Garbler<RNG, Wire>>,
49 RNG: CryptoRng + RngCore,
50 >(
51 c: &Circuit,
52 rng: RNG,
53 ) -> swanky_error::Result<(Encoder<Wire>, Self, OutputMapping)> {
54 let mut channel = GarbledChannel::new_writer(None);
55 let mut garbler = Channel::with(&mut channel, |channel| Garbler::new(rng, channel))?;
56
57 let gb_inps = (0..c.num_garbler_inputs())
59 .map(|i| {
60 let q = c.garbler_input_mod(i);
61 let (zero, _) = garbler.encode_wire(0, q);
62 zero
63 })
64 .collect_vec();
65
66 let ev_inps = (0..c.num_evaluator_inputs())
67 .map(|i| {
68 let q = c.evaluator_input_mod(i);
69 let (zero, _) = garbler.encode_wire(0, q);
70 zero
71 })
72 .collect_vec();
73
74 let zeros = Channel::with(&mut channel, |channel| {
75 let zeros = c.eval_to_wirelabels(&mut garbler, &gb_inps, &ev_inps, channel)?;
78 garbler.outputs(&zeros, channel)?;
83 Ok(zeros)
84 })?;
85
86 let deltas = garbler.get_deltas();
87 let en = Encoder::new(gb_inps, ev_inps, deltas.clone());
88 let gc = GarbledCircuit::new(channel.finish_writing());
89 let output_mapping = OutputMapping::new(&zeros, &deltas);
90
91 Ok((en, gc, output_mapping))
92 }
93
94 pub fn eval<Wire: WireLabel, Circuit: EvaluableCircuit<Evaluator<Wire>>>(
96 &self,
97 c: &Circuit,
98 garbler_inputs: &[Wire],
99 evaluator_inputs: &[Wire],
100 ) -> swanky_error::Result<Vec<u16>> {
101 let output = Channel::with(GarbledChannel::from(self), |channel| {
102 let mut evaluator = Evaluator::new(channel)?;
103 let outputs = c.eval(&mut evaluator, garbler_inputs, evaluator_inputs, channel)?;
104 Ok(outputs.expect("evaluator outputs always are Some(u16)"))
105 })?;
106 Ok(output)
107 }
108
109 pub fn eval_to_wirelabels<Wire: WireLabel, Circuit: EvaluableCircuit<Evaluator<Wire>>>(
112 &self,
113 c: &Circuit,
114 garbler_inputs: &[Wire],
115 evaluator_inputs: &[Wire],
116 ) -> swanky_error::Result<Vec<Wire>> {
117 let wirelabels = Channel::with(GarbledChannel::from(self), |channel| {
118 let mut evaluator = Evaluator::new(channel)?;
119 let wirelabels =
120 c.eval_to_wirelabels(&mut evaluator, garbler_inputs, evaluator_inputs, channel)?;
121 Ok(wirelabels)
122 })?;
123 Ok(wirelabels)
124 }
125}
126
/// Turns plaintext input values into input wirelabels.
///
/// Holds the zero wirelabel of every garbler and evaluator input, and the
/// delta wirelabel for each modulus; an input `x` is encoded as
/// `zero + delta * x`.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Encoder<Wire> {
    /// Zero wirelabel for each garbler input, in input order.
    garbler_inputs: Vec<Wire>,
    /// Zero wirelabel for each evaluator input, in input order.
    evaluator_inputs: Vec<Wire>,
    /// Delta wirelabel for each modulus, keyed by that modulus.
    deltas: HashMap<u16, Wire>,
}
138
139impl<Wire: WireLabel> Encoder<Wire> {
140 pub fn new(
143 garbler_inputs: Vec<Wire>,
144 evaluator_inputs: Vec<Wire>,
145 deltas: HashMap<u16, Wire>,
146 ) -> Self {
147 Encoder {
148 garbler_inputs,
149 evaluator_inputs,
150 deltas,
151 }
152 }
153
154 pub fn encode_garbler_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
160 assert_eq!(inputs.len(), self.garbler_inputs.len());
161 self.garbler_inputs
162 .iter()
163 .zip(inputs)
164 .map(|(zero, x)| {
165 let q = zero.modulus();
166 zero.clone() + self.deltas[&q].clone() * *x
167 })
168 .collect()
169 }
170
171 pub fn encode_evaluator_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
177 assert_eq!(inputs.len(), self.evaluator_inputs.len());
178 self.evaluator_inputs
179 .iter()
180 .zip(inputs)
181 .map(|(zero, x)| {
182 let q = zero.modulus();
183 zero.clone() + self.deltas[&q].clone() * *x
184 })
185 .collect()
186 }
187}
188
189#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
191pub struct OutputMapping(Vec<Vec<U8x16>>);
192
193impl OutputMapping {
194 pub fn new<Wire: WireLabel>(zeros: &[Wire], deltas: &HashMap<u16, Wire>) -> Self {
197 let mut outputs = Vec::with_capacity(zeros.len());
198 for (i, zero) in zeros.iter().enumerate() {
199 let q = zero.modulus();
200 let mut wirelabels = Vec::with_capacity(q as usize);
201 for k in 0..q {
202 let wirelabel = zero.clone() + deltas[&q].clone() * k;
203 let hashed = wirelabel.hash(output_tweak(i, k));
204 wirelabels.push(hashed);
205 }
206 outputs.push(wirelabels);
207 }
208 Self(outputs)
209 }
210
211 pub fn to_outputs<Wire: WireLabel>(
217 &self,
218 wirelabels: &[Wire],
219 ) -> swanky_error::Result<Vec<u16>> {
220 let mut outputs = Vec::new();
221 for (i, wirelabel) in wirelabels.iter().enumerate() {
222 let q = wirelabel.modulus();
223 let mut decoded = None;
224 for k in 0..q {
225 let hashed = wirelabel.hash(output_tweak(i, k));
226 if hashed == self.0[i][k as usize] {
227 decoded = Some(k);
228 break;
229 }
230 }
231 if let Some(output) = decoded {
232 outputs.push(output);
233 } else {
234 swanky_error::bail!(ErrorKind::OtherError, "Decoding failed");
235 }
236 }
237 Ok(outputs)
238 }
239}
240
241pub struct GarbledChannel {
256 reader: Option<GarbledReader>,
257 writer: Option<GarbledWriter>,
258}
259
260impl GarbledChannel {
261 pub fn new_writer(ngates: Option<usize>) -> Self {
263 Self {
264 reader: None,
265 writer: Some(GarbledWriter::new(ngates)),
266 }
267 }
268
269 pub fn finish_writing(self) -> Vec<U8x16> {
274 self.writer.unwrap().finish()
275 }
276}
277
278impl std::io::Read for GarbledChannel {
279 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
280 let reader = self.reader.as_mut().unwrap();
281 reader.read(buf)
282 }
283}
284
285impl std::io::Write for GarbledChannel {
286 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
287 let writer = self.writer.as_mut().unwrap();
288 writer.write(buf)
289 }
290
291 fn flush(&mut self) -> std::io::Result<()> {
292 let writer = self.writer.as_mut().unwrap();
293 writer.flush()
294 }
295}
296
297impl From<&GarbledCircuit> for GarbledChannel {
298 fn from(value: &GarbledCircuit) -> Self {
299 Self {
300 reader: Some(GarbledReader::new(&value.blocks)),
301 writer: None,
302 }
303 }
304}
305
/// In-memory reader that hands out stored 16-byte blocks in order.
#[derive(Debug)]
struct GarbledReader {
    /// Blocks to be replayed (an owned copy of the garbled circuit's blocks).
    blocks: Vec<U8x16>,
    /// Index of the next block to hand out.
    index: usize,
}
312
313impl GarbledReader {
314 fn new(blocks: &[U8x16]) -> Self {
315 Self {
316 blocks: blocks.to_vec(),
317 index: 0,
318 }
319 }
320}
321
322impl std::io::Read for GarbledReader {
323 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
324 assert_eq!(buf.len() % 16, 0);
325 let start = self.index;
326 for data in buf.chunks_mut(16) {
327 let block: [u8; 16] = self.blocks[self.index].into();
328 for (a, b) in data.iter_mut().zip(block.iter()) {
329 *a = *b;
330 }
331 self.index += 1;
332 if self.index == self.blocks.len() {
333 return Ok(16 * (self.index - start));
337 }
338 }
339 Ok(buf.len())
340 }
341}
342
/// In-memory writer that collects written bytes as 16-byte blocks.
#[derive(Debug)]
struct GarbledWriter {
    /// Blocks written so far, in write order.
    blocks: Vec<U8x16>,
}
348
349impl GarbledWriter {
350 fn new(ngates: Option<usize>) -> Self {
352 let blocks = if let Some(n) = ngates {
353 Vec::with_capacity(2 * n)
354 } else {
355 Vec::new()
356 };
357 Self { blocks }
358 }
359
360 fn finish(self) -> Vec<U8x16> {
362 self.blocks
363 }
364}
365
366impl std::io::Write for GarbledWriter {
367 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
368 for item in buf.chunks(16) {
369 let bytes: [u8; 16] = match item.try_into() {
370 Ok(bytes) => bytes,
371 Err(_) => {
372 return Err(std::io::Error::new(
373 std::io::ErrorKind::InvalidData,
374 "unable to map bytes to block",
375 ));
376 }
377 };
378 self.blocks.push(bytes.into());
379 }
380 Ok(buf.len())
381 }
382 fn flush(&mut self) -> std::io::Result<()> {
383 Ok(())
384 }
385}