1use crate::{
5 Fancy, WireLabel,
6 circuit::EvaluableCircuit,
7 garble::{Evaluator, Garbler},
8 util::output_tweak,
9};
10use itertools::Itertools;
11use rand::{CryptoRng, RngCore};
12use std::collections::HashMap;
13use swanky_channel::Channel;
14use swanky_error::ErrorKind;
15use vectoreyes::U8x16;
16
/// A garbled circuit stored as a flat list of 128-bit blocks.
///
/// The blocks are whatever was written to the in-memory channel while the
/// circuit was being garbled (see `GarbledCircuit::garble`), and they are
/// replayed in the same order when the circuit is evaluated.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GarbledCircuit {
    /// The recorded blocks, in the order they were written.
    blocks: Vec<U8x16>,
}
26
27impl GarbledCircuit {
28 pub fn new(blocks: Vec<U8x16>) -> Self {
31 GarbledCircuit { blocks }
32 }
33
34 pub fn size(&self) -> usize {
36 self.blocks.len()
37 }
38
39 pub fn garble<
47 Wire: WireLabel,
48 Circuit: EvaluableCircuit<Garbler<RNG, Wire>>,
49 RNG: CryptoRng + RngCore,
50 >(
51 c: &Circuit,
52 rng: RNG,
53 ) -> swanky_error::Result<(Encoder<Wire>, Self, OutputMapping)> {
54 let mut channel = GarbledChannel::new_writer(None);
55 let mut garbler = Channel::with(&mut channel, |channel| Garbler::new(rng, channel))?;
56
57 let inputs = (0..c.num_inputs())
59 .map(|i| {
60 let q = c.input_mod(i);
61 let (zero, _) = garbler.encode_wire(0, q);
62 zero
63 })
64 .collect_vec();
65
66 let zeros = Channel::with(&mut channel, |channel| {
67 let zeros = c.eval_to_wirelabels(&mut garbler, &inputs, channel)?;
70 garbler.outputs(&zeros, channel)?;
75 Ok(zeros)
76 })?;
77
78 let deltas = garbler.get_deltas();
79 let en = Encoder::new(inputs, deltas.clone());
80 let gc = GarbledCircuit::new(channel.finish_writing());
81 let output_mapping = OutputMapping::new(&zeros, &deltas);
82
83 Ok((en, gc, output_mapping))
84 }
85
86 pub fn eval<Wire: WireLabel, Circuit: EvaluableCircuit<Evaluator<Wire>>>(
88 &self,
89 c: &Circuit,
90 inputs: &[Wire],
91 ) -> swanky_error::Result<Vec<u16>> {
92 let output = Channel::with(GarbledChannel::from(self), |channel| {
93 let mut evaluator = Evaluator::new(channel)?;
94 let outputs = c.eval(&mut evaluator, inputs, channel)?;
95 Ok(outputs.expect("evaluator outputs always are Some(u16)"))
96 })?;
97 Ok(output)
98 }
99
100 pub fn eval_to_wirelabels<Wire: WireLabel, Circuit: EvaluableCircuit<Evaluator<Wire>>>(
103 &self,
104 c: &Circuit,
105 inputs: &[Wire],
106 ) -> swanky_error::Result<Vec<Wire>> {
107 let wirelabels = Channel::with(GarbledChannel::from(self), |channel| {
108 let mut evaluator = Evaluator::new(channel)?;
109 let wirelabels = c.eval_to_wirelabels(&mut evaluator, inputs, channel)?;
110 Ok(wirelabels)
111 })?;
112 Ok(wirelabels)
113 }
114}
115
/// Turns plaintext input values into input wire labels.
///
/// Holds one zero wire label per circuit input plus one delta per modulus;
/// `encode_inputs` combines them as `zero + delta * value`.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Encoder<Wire> {
    /// Zero wire label of each circuit input, in input order.
    inputs: Vec<Wire>,
    /// Delta wire label for each modulus, keyed by modulus.
    deltas: HashMap<u16, Wire>,
}
126
127impl<Wire: WireLabel> Encoder<Wire> {
128 pub fn new(inputs: Vec<Wire>, deltas: HashMap<u16, Wire>) -> Self {
131 Encoder { inputs, deltas }
132 }
133
134 pub fn encode_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
140 assert_eq!(inputs.len(), self.inputs.len());
141 self.inputs
142 .iter()
143 .zip(inputs)
144 .map(|(zero, x)| {
145 let q = zero.modulus();
146 zero.clone() + self.deltas[&q].clone() * *x
147 })
148 .collect()
149 }
150}
151
/// Lookup table used to decode output wire labels into plaintext values.
///
/// Entry `[i][k]` is the hash (under `output_tweak(i, k)`) of the wire label
/// that represents value `k` on output `i`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct OutputMapping(Vec<Vec<U8x16>>);
155
156impl OutputMapping {
157 pub fn new<Wire: WireLabel>(zeros: &[Wire], deltas: &HashMap<u16, Wire>) -> Self {
160 let mut outputs = Vec::with_capacity(zeros.len());
161 for (i, zero) in zeros.iter().enumerate() {
162 let q = zero.modulus();
163 let mut wirelabels = Vec::with_capacity(q as usize);
164 for k in 0..q {
165 let wirelabel = zero.clone() + deltas[&q].clone() * k;
166 let hashed = wirelabel.hash(output_tweak(i, k));
167 wirelabels.push(hashed);
168 }
169 outputs.push(wirelabels);
170 }
171 Self(outputs)
172 }
173
174 pub fn to_outputs<Wire: WireLabel>(
180 &self,
181 wirelabels: &[Wire],
182 ) -> swanky_error::Result<Vec<u16>> {
183 let mut outputs = Vec::new();
184 for (i, wirelabel) in wirelabels.iter().enumerate() {
185 let q = wirelabel.modulus();
186 let mut decoded = None;
187 for k in 0..q {
188 let hashed = wirelabel.hash(output_tweak(i, k));
189 if hashed == self.0[i][k as usize] {
190 decoded = Some(k);
191 break;
192 }
193 }
194 if let Some(output) = decoded {
195 outputs.push(output);
196 } else {
197 swanky_error::bail!(ErrorKind::OtherError, "Decoding failed");
198 }
199 }
200 Ok(outputs)
201 }
202}
203
/// An in-memory channel that either records written blocks or replays the
/// blocks of an existing `GarbledCircuit`.
///
/// `GarbledChannel::new_writer` creates a channel with only the writer half
/// set; converting from a `&GarbledCircuit` creates one with only the reader
/// half set. The `Read`/`Write` implementations unwrap the half they need,
/// so using the missing half panics.
pub struct GarbledChannel {
    /// Present when the channel replays blocks from a garbled circuit.
    reader: Option<GarbledReader>,
    /// Present when the channel records blocks being written.
    writer: Option<GarbledWriter>,
}
222
223impl GarbledChannel {
224 pub fn new_writer(ngates: Option<usize>) -> Self {
226 Self {
227 reader: None,
228 writer: Some(GarbledWriter::new(ngates)),
229 }
230 }
231
232 pub fn finish_writing(self) -> Vec<U8x16> {
237 self.writer.unwrap().finish()
238 }
239}
240
241impl std::io::Read for GarbledChannel {
242 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
243 let reader = self.reader.as_mut().unwrap();
244 reader.read(buf)
245 }
246}
247
248impl std::io::Write for GarbledChannel {
249 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
250 let writer = self.writer.as_mut().unwrap();
251 writer.write(buf)
252 }
253
254 fn flush(&mut self) -> std::io::Result<()> {
255 let writer = self.writer.as_mut().unwrap();
256 writer.flush()
257 }
258}
259
260impl From<&GarbledCircuit> for GarbledChannel {
261 fn from(value: &GarbledCircuit) -> Self {
262 Self {
263 reader: Some(GarbledReader::new(&value.blocks)),
264 writer: None,
265 }
266 }
267}
268
/// Reader half of a `GarbledChannel`: hands out a list of blocks
/// sequentially, 16 bytes per block.
#[derive(Debug)]
struct GarbledReader {
    /// Blocks to hand out (an owned copy of the slice given to `new`).
    blocks: Vec<U8x16>,
    /// Index of the next block to hand out.
    index: usize,
}
275
276impl GarbledReader {
277 fn new(blocks: &[U8x16]) -> Self {
278 Self {
279 blocks: blocks.to_vec(),
280 index: 0,
281 }
282 }
283}
284
285impl std::io::Read for GarbledReader {
286 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
287 assert_eq!(buf.len() % 16, 0);
288 let start = self.index;
289 for data in buf.chunks_mut(16) {
290 let block: [u8; 16] = self.blocks[self.index].into();
291 for (a, b) in data.iter_mut().zip(block.iter()) {
292 *a = *b;
293 }
294 self.index += 1;
295 if self.index == self.blocks.len() {
296 return Ok(16 * (self.index - start));
300 }
301 }
302 Ok(buf.len())
303 }
304}
305
/// Writer half of a `GarbledChannel`: records everything written to it as
/// 16-byte blocks.
#[derive(Debug)]
struct GarbledWriter {
    /// Blocks recorded so far, in write order.
    blocks: Vec<U8x16>,
}
311
312impl GarbledWriter {
313 fn new(ngates: Option<usize>) -> Self {
315 let blocks = if let Some(n) = ngates {
316 Vec::with_capacity(2 * n)
317 } else {
318 Vec::new()
319 };
320 Self { blocks }
321 }
322
323 fn finish(self) -> Vec<U8x16> {
325 self.blocks
326 }
327}
328
329impl std::io::Write for GarbledWriter {
330 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
331 for item in buf.chunks(16) {
332 let bytes: [u8; 16] = match item.try_into() {
333 Ok(bytes) => bytes,
334 Err(_) => {
335 return Err(std::io::Error::new(
336 std::io::ErrorKind::InvalidData,
337 "unable to map bytes to block",
338 ));
339 }
340 };
341 self.blocks.push(bytes.into());
342 }
343 Ok(buf.len())
344 }
345 fn flush(&mut self) -> std::io::Result<()> {
346 Ok(())
347 }
348}