1use crate::{
2 FancyArithmetic, FancyBinary, HasModulus, check_binary,
3 circuit::{CircuitBuilder, CircuitRef, CircuitType, EvaluableCircuit},
4};
5use swanky_channel::Channel;
6
7#[derive(Clone, Debug, PartialEq)]
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10pub struct ArithmeticCircuit {
11 pub(crate) gates: Vec<ArithmeticGate>,
12 pub(crate) gate_moduli: Vec<u16>,
13 pub(crate) garbler_input_refs: Vec<CircuitRef>,
14 pub(crate) evaluator_input_refs: Vec<CircuitRef>,
15 pub(crate) const_refs: Vec<CircuitRef>,
16 pub(crate) output_refs: Vec<CircuitRef>,
17 pub(crate) num_nonfree_gates: usize,
18}
19
20#[derive(Clone, Debug, PartialEq)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26pub enum ArithmeticGate {
27 GarblerInput {
29 id: usize,
31 },
32 EvaluatorInput {
34 id: usize,
36 },
37 Constant {
39 val: u16,
41 },
42 Add {
44 xref: CircuitRef,
46
47 yref: CircuitRef,
49
50 out: Option<usize>,
52 },
53 Sub {
55 xref: CircuitRef,
57
58 yref: CircuitRef,
60
61 out: Option<usize>,
63 },
64 Cmul {
66 xref: CircuitRef,
68
69 c: u16,
71
72 out: Option<usize>,
74 },
75 Mul {
77 xref: CircuitRef,
79
80 yref: CircuitRef,
82
83 id: usize,
85
86 out: Option<usize>,
88 },
89 Proj {
91 xref: CircuitRef,
93
94 tt: Vec<u16>,
96
97 id: usize,
99
100 out: Option<usize>,
102 },
103}
104
105impl std::fmt::Display for ArithmeticGate {
106 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
107 match self {
108 Self::GarblerInput { id } => write!(f, "GarblerInput {}", id),
109 Self::EvaluatorInput { id } => write!(f, "EvaluatorInput {}", id),
110 Self::Constant { val } => write!(f, "Constant {}", val),
111 Self::Add { xref, yref, out } => write!(f, "Add ( {}, {}, {:?} )", xref, yref, out),
112 Self::Sub { xref, yref, out } => write!(f, "Sub ( {}, {}, {:?} )", xref, yref, out),
113 Self::Cmul { xref, c, out } => write!(f, "Cmul ( {}, {}, {:?} )", xref, c, out),
114 Self::Mul {
115 xref,
116 yref,
117 id,
118 out,
119 } => write!(f, "Mul ( {}, {}, {}, {:?} )", xref, yref, id, out),
120 Self::Proj { xref, tt, id, out } => {
121 write!(f, "Proj ( {}, {:?}, {}, {:?} )", xref, tt, id, out)
122 }
123 }
124 }
125}
126
127impl<F: FancyArithmetic> EvaluableCircuit<F> for ArithmeticCircuit {
128 fn eval_to_wirelabels(
129 &self,
130 f: &mut F,
131 garbler_inputs: &[F::Item],
132 evaluator_inputs: &[F::Item],
133 channel: &mut Channel,
134 ) -> swanky_error::Result<Vec<F::Item>> {
135 let mut cache: Vec<Option<F::Item>> = vec![None; self.gates.len()];
136 for (i, gate) in self.gates.iter().enumerate() {
137 let q = self.modulus(i);
138 let (zref_, val) = match *gate {
139 ArithmeticGate::GarblerInput { id } => (None, garbler_inputs[id].clone()),
140 ArithmeticGate::EvaluatorInput { id } => {
141 assert!(
142 id < evaluator_inputs.len(),
143 "id={} ev_inps.len()={}",
144 id,
145 evaluator_inputs.len()
146 );
147 (None, evaluator_inputs[id].clone())
148 }
149 ArithmeticGate::Constant { val } => (None, f.constant(val, q, channel)?),
150 ArithmeticGate::Add { xref, yref, out } => (
151 out,
152 f.add(
153 cache[xref.ix].as_ref().unwrap(),
154 cache[yref.ix].as_ref().unwrap(),
155 ),
156 ),
157 ArithmeticGate::Sub { xref, yref, out } => (
158 out,
159 f.sub(
160 cache[xref.ix].as_ref().unwrap(),
161 cache[yref.ix].as_ref().unwrap(),
162 ),
163 ),
164 ArithmeticGate::Cmul { xref, c, out } => {
165 (out, f.cmul(cache[xref.ix].as_ref().unwrap(), c))
166 }
167 ArithmeticGate::Proj {
168 xref, ref tt, out, ..
169 } => (
170 out,
171 f.proj(
172 cache[xref.ix].as_ref().unwrap(),
173 q,
174 Some(tt.to_vec()),
175 channel,
176 )?,
177 ),
178 ArithmeticGate::Mul {
179 xref, yref, out, ..
180 } => (
181 out,
182 f.mul(
183 cache[xref.ix].as_ref().unwrap(),
184 cache[yref.ix].as_ref().unwrap(),
185 channel,
186 )?,
187 ),
188 };
189 cache[zref_.unwrap_or(i)] = Some(val);
190 }
191 let mut outputs = Vec::with_capacity(self.noutputs());
192 for r in self.get_output_refs().iter() {
193 let wirelabel = cache[r.ix].as_ref().unwrap();
194 outputs.push(wirelabel.clone());
195 }
196 Ok(outputs)
197 }
198}
199
200impl CircuitType for ArithmeticCircuit {
201 type Gate = ArithmeticGate;
202
203 fn new(ngates: Option<usize>) -> ArithmeticCircuit {
204 let gates = Vec::with_capacity(ngates.unwrap_or(0));
205 ArithmeticCircuit {
206 gates,
207 garbler_input_refs: Vec::new(),
208 evaluator_input_refs: Vec::new(),
209 const_refs: Vec::new(),
210 output_refs: Vec::new(),
211 gate_moduli: Vec::new(),
212 num_nonfree_gates: 0,
213 }
214 }
215
216 fn push_gates(&mut self, gate: Self::Gate) {
217 self.gates.push(gate)
218 }
219
220 fn push_const_ref(&mut self, xref: CircuitRef) {
221 self.const_refs.push(xref)
222 }
223
224 fn push_output_ref(&mut self, xref: CircuitRef) {
225 self.output_refs.push(xref)
226 }
227
228 fn push_garbler_input_ref(&mut self, xref: CircuitRef) {
229 self.garbler_input_refs.push(xref)
230 }
231
232 fn push_modulus(&mut self, modulus: u16) {
233 self.gate_moduli.push(modulus)
234 }
235
236 fn push_evaluator_input_ref(&mut self, xref: CircuitRef) {
237 self.evaluator_input_refs.push(xref)
238 }
239
240 fn increment_nonfree_gates(&mut self) {
241 self.num_nonfree_gates += 1;
242 }
243
244 fn get_num_nonfree_gates(&self) -> usize {
245 self.num_nonfree_gates
246 }
247
248 fn get_output_refs(&self) -> &[CircuitRef] {
249 &self.output_refs
250 }
251
252 fn get_garbler_input_refs(&self) -> &[CircuitRef] {
253 &self.garbler_input_refs
254 }
255
256 fn get_evaluator_input_refs(&self) -> &[CircuitRef] {
257 &self.evaluator_input_refs
258 }
259
260 fn garbler_input_mod(&self, i: usize) -> u16 {
261 let r = self.garbler_input_refs[i];
262 r.modulus()
263 }
264
265 fn evaluator_input_mod(&self, i: usize) -> u16 {
266 let r = self.evaluator_input_refs[i];
267 r.modulus()
268 }
269}
270
271impl ArithmeticCircuit {
272 #[inline]
274 pub fn modulus(&self, i: usize) -> u16 {
275 self.gate_moduli[i]
276 }
277}
278
279impl FancyBinary for CircuitBuilder<ArithmeticCircuit> {
280 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
281 check_binary!(x);
282 check_binary!(y);
283
284 self.add(x, y)
285 }
286
287 fn and(
288 &mut self,
289 x: &Self::Item,
290 y: &Self::Item,
291 channel: &mut Channel,
292 ) -> swanky_error::Result<Self::Item> {
293 check_binary!(x);
294 check_binary!(y);
295
296 self.mul(x, y, channel)
297 }
298
299 fn negate(&mut self, x: &Self::Item) -> Self::Item {
300 check_binary!(x);
301
302 let one = self.lookup_constant(1, 2);
303
304 self.xor(x, &one)
305 }
306}
307
308impl FancyArithmetic for CircuitBuilder<ArithmeticCircuit> {
309 fn add(&mut self, xref: &CircuitRef, yref: &CircuitRef) -> CircuitRef {
310 assert_eq!(xref.modulus(), yref.modulus());
311 let gate = ArithmeticGate::Add {
312 xref: *xref,
313 yref: *yref,
314 out: None,
315 };
316 self.gate(gate, xref.modulus())
317 }
318
319 fn sub(&mut self, xref: &CircuitRef, yref: &CircuitRef) -> CircuitRef {
320 assert_eq!(xref.modulus(), yref.modulus());
321 let gate = ArithmeticGate::Sub {
322 xref: *xref,
323 yref: *yref,
324 out: None,
325 };
326 self.gate(gate, xref.modulus())
327 }
328
329 fn cmul(&mut self, xref: &CircuitRef, c: u16) -> CircuitRef {
330 self.gate(
331 ArithmeticGate::Cmul {
332 xref: *xref,
333 c,
334 out: None,
335 },
336 xref.modulus(),
337 )
338 }
339
340 fn proj(
341 &mut self,
342 xref: &CircuitRef,
343 output_modulus: u16,
344 tt: Option<Vec<u16>>,
345 _: &mut Channel,
346 ) -> swanky_error::Result<CircuitRef> {
347 assert!(tt.is_some(), "`tt` must not be `None`");
348 let tt = tt.unwrap();
349 assert!(
350 tt.len() >= xref.modulus() as usize,
351 "`tt` not large enough for `x`s modulus"
352 );
353 assert!(
354 tt.iter().all(|&x| x < output_modulus),
355 "`tt` value larger than `q`"
356 );
357 let gate = ArithmeticGate::Proj {
358 xref: *xref,
359 tt: tt.to_vec(),
360 id: self.get_next_ciphertext_id(),
361 out: None,
362 };
363 Ok(self.gate(gate, output_modulus))
364 }
365
366 fn mul(
367 &mut self,
368 xref: &CircuitRef,
369 yref: &CircuitRef,
370 _channel: &mut Channel,
371 ) -> swanky_error::Result<CircuitRef> {
372 if xref.modulus() < yref.modulus() {
373 return self.mul(yref, xref, _channel);
374 }
375
376 let gate = ArithmeticGate::Mul {
377 xref: *xref,
378 yref: *yref,
379 id: self.get_next_ciphertext_id(),
380 out: None,
381 };
382
383 Ok(self.gate(gate, xref.modulus()))
384 }
385}