1use crate::{
5 WireLabel,
6 circuit::EvaluableCircuit,
7 errors::{EvaluatorError, GarblerError},
8 garble::{Evaluator, Garbler},
9};
10use itertools::Itertools;
11use std::{collections::HashMap, marker::PhantomData, rc::Rc};
12use swanky_aes_rng::AesRng;
13use swanky_block::Block;
14use swanky_channel_legacy::Channel;
15
/// A garbled circuit: the blocks a garbler wrote while garbling a circuit.
///
/// `W` (wire label type) and `C` (circuit type) are phantom parameters. They tie the
/// blocks to the label and circuit types they were produced for, so
/// [`GarbledCircuit::eval`] is only available for a matching circuit type.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GarbledCircuit<W, C> {
    // Blocks in the order the garbler wrote them; the evaluator reads them back in order.
    blocks: Vec<Block>,
    _phantom_wire: PhantomData<W>,
    _phantom_circ: PhantomData<C>,
}
26
27impl<W, C> GarbledCircuit<W, C> {
28 pub fn new(blocks: Vec<Block>) -> Self {
30 GarbledCircuit {
31 blocks,
32 _phantom_wire: PhantomData,
33 _phantom_circ: PhantomData,
34 }
35 }
36
37 pub fn size(&self) -> usize {
39 self.blocks.len()
40 }
41}
42
43type Ev<Wire> = Evaluator<Channel<GarbledReader, GarbledWriter>, Wire>;
44type Gb<Wire> = Garbler<Channel<GarbledReader, GarbledWriter>, AesRng, Wire>;
45
46impl<Wire: WireLabel, Circuit: EvaluableCircuit<Ev<Wire>>> GarbledCircuit<Wire, Circuit> {
47 pub fn eval(
49 &self,
50 c: &Circuit,
51 garbler_inputs: &[Wire],
52 evaluator_inputs: &[Wire],
53 ) -> Result<Vec<u16>, EvaluatorError> {
54 let channel = Channel::new(GarbledReader::new(&self.blocks), GarbledWriter::new(None));
55 let mut evaluator = Evaluator::new(channel);
56 let outputs = c.eval(&mut evaluator, garbler_inputs, evaluator_inputs)?;
57 Ok(outputs.expect("evaluator outputs always are Some(u16)"))
58 }
59}
60
61pub fn garble<Wire: WireLabel, Circuit: EvaluableCircuit<Gb<Wire>>>(
63 c: &Circuit,
64) -> Result<(Encoder<Wire>, GarbledCircuit<Wire, Circuit>), GarblerError> {
65 let channel = Channel::new(
66 GarbledReader::new(&[]),
67 GarbledWriter::new(Some(c.get_num_nonfree_gates())),
68 );
69 let channel_ = channel.clone();
70
71 let rng = AesRng::new();
72 let mut garbler = Garbler::new(channel_, rng);
73
74 let gb_inps = (0..c.num_garbler_inputs())
76 .map(|i| {
77 let q = c.garbler_input_mod(i);
78 let (zero, _) = garbler.encode_wire(0, q);
79 zero
80 })
81 .collect_vec();
82
83 let ev_inps = (0..c.num_evaluator_inputs())
84 .map(|i| {
85 let q = c.evaluator_input_mod(i);
86 let (zero, _) = garbler.encode_wire(0, q);
87 zero
88 })
89 .collect_vec();
90
91 c.eval(&mut garbler, &gb_inps, &ev_inps)?;
92
93 let en = Encoder::new(gb_inps, ev_inps, garbler.get_deltas());
94
95 let gc = GarbledCircuit::new(
96 Rc::try_unwrap(channel.writer())
97 .unwrap()
98 .into_inner()
99 .blocks,
100 );
101
102 Ok((en, gc))
103}
104
/// Encodes plaintext input values into wire labels for a circuit garbled by [`garble`].
///
/// A value `x` for an input is encoded as that input's zero label plus `x` times the
/// delta for the input's modulus.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Encoder<Wire> {
    // Label encoding value 0 for each garbler input, indexed by input id.
    garbler_inputs: Vec<Wire>,
    // Label encoding value 0 for each evaluator input, indexed by input id.
    evaluator_inputs: Vec<Wire>,
    // Delta label per modulus; added `x` times to a zero label to encode `x`.
    deltas: HashMap<u16, Wire>,
}
116
117impl<Wire: WireLabel> Encoder<Wire> {
118 pub fn new(
121 garbler_inputs: Vec<Wire>,
122 evaluator_inputs: Vec<Wire>,
123 deltas: HashMap<u16, Wire>,
124 ) -> Self {
125 Encoder {
126 garbler_inputs,
127 evaluator_inputs,
128 deltas,
129 }
130 }
131
132 pub fn num_garbler_inputs(&self) -> usize {
134 self.garbler_inputs.len()
135 }
136
137 pub fn num_evaluator_inputs(&self) -> usize {
139 self.evaluator_inputs.len()
140 }
141
142 pub fn encode_garbler_input(&self, x: u16, id: usize) -> Wire {
144 let X = &self.garbler_inputs[id];
145 let q = X.modulus();
146 X.plus(&self.deltas[&q].cmul(x))
147 }
148
149 pub fn encode_evaluator_input(&self, x: u16, id: usize) -> Wire {
151 let X = &self.evaluator_inputs[id];
152 let q = X.modulus();
153 X.plus(&self.deltas[&q].cmul(x))
154 }
155
156 pub fn encode_garbler_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
158 debug_assert_eq!(inputs.len(), self.garbler_inputs.len());
159 (0..inputs.len())
160 .zip(inputs)
161 .map(|(id, &x)| self.encode_garbler_input(x, id))
162 .collect()
163 }
164
165 pub fn encode_evaluator_inputs(&self, inputs: &[u16]) -> Vec<Wire> {
167 debug_assert_eq!(inputs.len(), self.evaluator_inputs.len());
168 (0..inputs.len())
169 .zip(inputs)
170 .map(|(id, &x)| self.encode_evaluator_input(x, id))
171 .collect()
172 }
173}
174
/// [`std::io::Read`] adapter that hands out a fixed, in-memory sequence of blocks,
/// 16 bytes per block, in order.
#[derive(Debug)]
pub struct GarbledReader {
    // Blocks served by successive reads.
    blocks: Vec<Block>,
    // Position in `blocks` of the next block to hand out.
    index: usize,
}
184
185impl GarbledReader {
186 fn new(blocks: &[Block]) -> Self {
187 Self {
188 blocks: blocks.to_vec(),
189 index: 0,
190 }
191 }
192}
193
194impl std::io::Read for GarbledReader {
195 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
196 assert_eq!(buf.len() % 16, 0);
197 for data in buf.chunks_mut(16) {
198 let block: [u8; 16] = self.blocks[self.index].into();
199 for (a, b) in data.iter_mut().zip(block.iter()) {
200 *a = *b;
201 }
202 self.index += 1;
203 }
204 Ok(buf.len())
205 }
206}
207
/// [`std::io::Write`] adapter that collects everything written to it as a sequence of
/// 16-byte blocks held in memory.
#[derive(Debug)]
pub struct GarbledWriter {
    // Every 16-byte chunk written so far, in write order.
    blocks: Vec<Block>,
}
213
214impl GarbledWriter {
215 pub fn new(ngates: Option<usize>) -> Self {
217 let blocks = if let Some(n) = ngates {
218 Vec::with_capacity(2 * n)
219 } else {
220 Vec::new()
221 };
222 Self { blocks }
223 }
224}
225
226impl std::io::Write for GarbledWriter {
227 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
228 for item in buf.chunks(16) {
229 let bytes: [u8; 16] = match item.try_into() {
230 Ok(bytes) => bytes,
231 Err(_) => {
232 return Err(std::io::Error::new(
233 std::io::ErrorKind::InvalidData,
234 "unable to map bytes to block",
235 ));
236 }
237 };
238 self.blocks.push(Block::from(bytes));
239 }
240 Ok(buf.len())
241 }
242 fn flush(&mut self) -> std::io::Result<()> {
243 Ok(())
244 }
245}