1use crate::{
5 Fancy, WireLabel,
6 circuit::{Circuit, CircuitInputMapper, Flatten},
7 garble::{Evaluator, Garbler},
8 util::output_tweak,
9};
10use rand::{CryptoRng, RngCore};
11use std::collections::HashMap;
12use swanky_channel::Channel;
13use swanky_error::ErrorKind;
14use vectoreyes::U8x16;
15
16#[derive(Debug)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct GarbledCircuit {
23 blocks: Vec<U8x16>,
24}
25
26impl GarbledCircuit {
27 pub fn new(blocks: Vec<U8x16>) -> Self {
30 GarbledCircuit { blocks }
31 }
32
33 pub fn size(&self) -> usize {
35 self.blocks.len()
36 }
37
38 pub fn garble<
46 Wire: WireLabel,
47 C: CircuitInputMapper<Garbler<RNG, Wire>>,
48 RNG: CryptoRng + RngCore,
49 >(
50 circuit: &C,
51 rng: RNG,
52 ) -> swanky_error::Result<(Encoder<Wire>, Self, OutputMapping)> {
53 let mut channel = GarbledChannel::new_writer(None);
54 let (en, output_mapping) = Channel::with(&mut channel, |channel| {
55 let mut garbler = Garbler::new(rng, channel)?;
56
57 let inputs = (0..circuit.ninputs())
59 .map(|i| {
60 let q = circuit.modulus(i);
61 garbler.encode_zero(q)
62 })
63 .collect::<Vec<_>>();
64
65 let zeros = circuit.execute(&mut garbler, &circuit.map(inputs.clone()), channel)?;
68 let zeros = zeros.flatten();
69 garbler.outputs(&zeros, channel)?;
74
75 let deltas = garbler.get_deltas();
76 let en = Encoder::new(inputs, deltas.clone());
77 let output_mapping = OutputMapping::new(&zeros, &deltas);
78
79 Ok((en, output_mapping))
80 })?;
81 let gc = GarbledCircuit::new(channel.finish_writing());
82 Ok((en, gc, output_mapping))
83 }
84
85 pub fn eval<Wire: WireLabel, C: Circuit<Evaluator<Wire>>>(
88 &self,
89 circuit: &C,
90 inputs: &C::Input,
91 output_mapping: &OutputMapping,
92 ) -> swanky_error::Result<Vec<u16>> {
93 let wirelabels = self.eval_to_wirelabels(circuit, inputs)?;
94 output_mapping.to_outputs(&wirelabels.flatten())
95 }
96
97 pub fn eval_to_wirelabels<Wire: WireLabel, C: Circuit<Evaluator<Wire>>>(
100 &self,
101 circuit: &C,
102 inputs: &C::Input,
103 ) -> swanky_error::Result<C::Output> {
104 Channel::with(GarbledChannel::from(self), |channel| {
105 let mut evaluator = Evaluator::new(channel)?;
106 circuit.execute(&mut evaluator, inputs, channel)
107 })
108 }
109}
110
111#[derive(Debug)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
114pub struct Encoder<Wire> {
115 inputs: Vec<Wire>,
116 deltas: HashMap<u16, Wire>,
117}
118
119impl<Wire: WireLabel> Encoder<Wire> {
120 pub fn new(inputs: Vec<Wire>, deltas: HashMap<u16, Wire>) -> Self {
123 Encoder { inputs, deltas }
124 }
125
126 pub fn encode_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
132 assert_eq!(inputs.len(), self.inputs.len());
133 self.inputs
134 .iter()
135 .zip(inputs)
136 .map(|(zero, x)| {
137 let q = zero.modulus();
138 zero.clone() + self.deltas[&q].clone() * *x
139 })
140 .collect()
141 }
142}
143
144#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
146pub struct OutputMapping(Vec<Vec<U8x16>>);
147
148impl OutputMapping {
149 pub fn new<Wire: WireLabel>(zeros: &[Wire], deltas: &HashMap<u16, Wire>) -> Self {
152 let mut outputs = Vec::with_capacity(zeros.len());
153 for (i, zero) in zeros.iter().enumerate() {
154 let q = zero.modulus();
155 let mut wirelabels = Vec::with_capacity(q as usize);
156 for k in 0..q {
157 let wirelabel = zero.clone() + deltas[&q].clone() * k;
158 let hashed = wirelabel.hash(output_tweak(i, k));
159 wirelabels.push(hashed);
160 }
161 outputs.push(wirelabels);
162 }
163 Self(outputs)
164 }
165
166 pub fn to_outputs<Wire: WireLabel>(
172 &self,
173 wirelabels: &[Wire],
174 ) -> swanky_error::Result<Vec<u16>> {
175 let mut outputs = Vec::new();
176 for (i, wirelabel) in wirelabels.iter().enumerate() {
177 let q = wirelabel.modulus();
178 let mut decoded = None;
179 for k in 0..q {
180 let hashed = wirelabel.hash(output_tweak(i, k));
181 if hashed == self.0[i][k as usize] {
182 decoded = Some(k);
183 break;
184 }
185 }
186 if let Some(output) = decoded {
187 outputs.push(output);
188 } else {
189 swanky_error::bail!(ErrorKind::OtherError, "Decoding failed");
190 }
191 }
192 Ok(outputs)
193 }
194}
195
196pub struct GarbledChannel {
211 reader: Option<GarbledReader>,
212 writer: Option<GarbledWriter>,
213}
214
215impl GarbledChannel {
216 pub fn new_writer(ngates: Option<usize>) -> Self {
218 Self {
219 reader: None,
220 writer: Some(GarbledWriter::new(ngates)),
221 }
222 }
223
224 pub fn finish_writing(self) -> Vec<U8x16> {
229 self.writer.unwrap().finish()
230 }
231}
232
233impl std::io::Read for GarbledChannel {
234 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
235 let reader = self.reader.as_mut().unwrap();
236 reader.read(buf)
237 }
238}
239
240impl std::io::Write for GarbledChannel {
241 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
242 let writer = self.writer.as_mut().unwrap();
243 writer.write(buf)
244 }
245
246 fn flush(&mut self) -> std::io::Result<()> {
247 let writer = self.writer.as_mut().unwrap();
248 writer.flush()
249 }
250}
251
252impl From<&GarbledCircuit> for GarbledChannel {
253 fn from(value: &GarbledCircuit) -> Self {
254 Self {
255 reader: Some(GarbledReader::new(&value.blocks)),
256 writer: None,
257 }
258 }
259}
260
261#[derive(Debug)]
263struct GarbledReader {
264 blocks: Vec<U8x16>,
265 index: usize,
266}
267
268impl GarbledReader {
269 fn new(blocks: &[U8x16]) -> Self {
270 Self {
271 blocks: blocks.to_vec(),
272 index: 0,
273 }
274 }
275}
276
277impl std::io::Read for GarbledReader {
278 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
279 assert_eq!(buf.len() % 16, 0);
280 let start = self.index;
281 for data in buf.chunks_mut(16) {
282 let block: [u8; 16] = self.blocks[self.index].into();
283 for (a, b) in data.iter_mut().zip(block.iter()) {
284 *a = *b;
285 }
286 self.index += 1;
287 if self.index == self.blocks.len() {
288 return Ok(16 * (self.index - start));
292 }
293 }
294 Ok(buf.len())
295 }
296}
297
298#[derive(Debug)]
300struct GarbledWriter {
301 blocks: Vec<U8x16>,
302}
303
304impl GarbledWriter {
305 fn new(ngates: Option<usize>) -> Self {
307 let blocks = if let Some(n) = ngates {
308 Vec::with_capacity(2 * n)
309 } else {
310 Vec::new()
311 };
312 Self { blocks }
313 }
314
315 fn finish(self) -> Vec<U8x16> {
317 self.blocks
318 }
319}
320
321impl std::io::Write for GarbledWriter {
322 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
323 for item in buf.chunks(16) {
324 let bytes: [u8; 16] = match item.try_into() {
325 Ok(bytes) => bytes,
326 Err(_) => {
327 return Err(std::io::Error::new(
328 std::io::ErrorKind::InvalidData,
329 "unable to map bytes to block",
330 ));
331 }
332 };
333 self.blocks.push(bytes.into());
334 }
335 Ok(buf.len())
336 }
337 fn flush(&mut self) -> std::io::Result<()> {
338 Ok(())
339 }
340}