fancy_garbling/circuit/arithmetic.rs

1use crate::{
2    FancyArithmetic, FancyBinary, HasModulus, check_binary,
3    circuit::{CircuitBuilder, CircuitRef, CircuitType, EvaluableCircuit},
4};
5use swanky_channel::Channel;
6
/// Static representation of arithmetic computation supported by fancy garbling.
///
/// Evaluation walks `gates` in order and caches each result by wire index, so a
/// gate must only read wires written by earlier gates.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ArithmeticCircuit {
    /// Gates in evaluation order.
    pub(crate) gates: Vec<ArithmeticGate>,
    /// Modulus of each gate, indexed by gate number (read via `modulus(i)`).
    pub(crate) gate_moduli: Vec<u16>,
    /// References to the garbler's input wires; `garbler_input_mod(i)` returns
    /// the modulus of the `i`th one.
    pub(crate) garbler_input_refs: Vec<CircuitRef>,
    /// References to the evaluator's input wires; `evaluator_input_mod(i)`
    /// returns the modulus of the `i`th one.
    pub(crate) evaluator_input_refs: Vec<CircuitRef>,
    /// References recorded through `push_const_ref` (presumably the circuit's
    /// constant wires — confirm against the builder).
    pub(crate) const_refs: Vec<CircuitRef>,
    /// References to the wires whose values are returned by evaluation.
    pub(crate) output_refs: Vec<CircuitRef>,
    /// Counter bumped by `increment_nonfree_gates`; presumably the number of
    /// gates that need ciphertexts (e.g. `Mul`/`Proj`) — TODO confirm.
    pub(crate) num_nonfree_gates: usize,
}
19
/// Arithmetic computation supported by fancy garbling.
///
/// For input gates, `id` indexes the corresponding input list supplied at
/// evaluation time. For `Mul` and `Proj`, `id` is the ciphertext id assigned
/// by the circuit builder. `out` gives the output wire index; if `out = None`,
/// then we use the gate index as the output wire index.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ArithmeticGate {
    /// Input of garbler
    GarblerInput {
        /// Index into the garbler's inputs (`garbler_inputs[id]` during evaluation)
        id: usize,
    },
    /// Input of evaluator
    EvaluatorInput {
        /// Index into the evaluator's inputs (`evaluator_inputs[id]` during evaluation)
        id: usize,
    },
    /// Constant value
    Constant {
        /// Value of constant, interpreted under this gate's modulus
        val: u16,
    },
    /// Add gate
    Add {
        /// Reference to input 1
        xref: CircuitRef,

        /// Reference to input 2
        yref: CircuitRef,

        /// Output wire index
        out: Option<usize>,
    },
    /// Sub gate
    Sub {
        /// Reference to input 1
        xref: CircuitRef,

        /// Reference to input 2
        yref: CircuitRef,

        /// Output wire index
        out: Option<usize>,
    },
    /// Constant multiplication gate
    Cmul {
        /// Reference to input 1
        xref: CircuitRef,

        /// Constant to multiply by
        c: u16,

        /// Output wire index
        out: Option<usize>,
    },
    /// Multiplication gate
    Mul {
        /// Reference to input 1
        xref: CircuitRef,

        /// Reference to input 2
        yref: CircuitRef,

        /// Ciphertext id assigned by the builder
        id: usize,

        /// Output wire index
        out: Option<usize>,
    },
    /// Projection gate
    Proj {
        /// Reference to input 1
        xref: CircuitRef,

        /// Projection truth table: entry `k` is the output for input value `k`
        tt: Vec<u16>,

        /// Ciphertext id assigned by the builder
        id: usize,

        /// Output wire index
        out: Option<usize>,
    },
}
104
105impl std::fmt::Display for ArithmeticGate {
106    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
107        match self {
108            Self::GarblerInput { id } => write!(f, "GarblerInput {}", id),
109            Self::EvaluatorInput { id } => write!(f, "EvaluatorInput {}", id),
110            Self::Constant { val } => write!(f, "Constant {}", val),
111            Self::Add { xref, yref, out } => write!(f, "Add ( {}, {}, {:?} )", xref, yref, out),
112            Self::Sub { xref, yref, out } => write!(f, "Sub ( {}, {}, {:?} )", xref, yref, out),
113            Self::Cmul { xref, c, out } => write!(f, "Cmul ( {}, {}, {:?} )", xref, c, out),
114            Self::Mul {
115                xref,
116                yref,
117                id,
118                out,
119            } => write!(f, "Mul ( {}, {}, {}, {:?} )", xref, yref, id, out),
120            Self::Proj { xref, tt, id, out } => {
121                write!(f, "Proj ( {}, {:?}, {}, {:?} )", xref, tt, id, out)
122            }
123        }
124    }
125}
126
127impl<F: FancyArithmetic> EvaluableCircuit<F> for ArithmeticCircuit {
128    fn eval_to_wirelabels(
129        &self,
130        f: &mut F,
131        garbler_inputs: &[F::Item],
132        evaluator_inputs: &[F::Item],
133        channel: &mut Channel,
134    ) -> swanky_error::Result<Vec<F::Item>> {
135        let mut cache: Vec<Option<F::Item>> = vec![None; self.gates.len()];
136        for (i, gate) in self.gates.iter().enumerate() {
137            let q = self.modulus(i);
138            let (zref_, val) = match *gate {
139                ArithmeticGate::GarblerInput { id } => (None, garbler_inputs[id].clone()),
140                ArithmeticGate::EvaluatorInput { id } => {
141                    assert!(
142                        id < evaluator_inputs.len(),
143                        "id={} ev_inps.len()={}",
144                        id,
145                        evaluator_inputs.len()
146                    );
147                    (None, evaluator_inputs[id].clone())
148                }
149                ArithmeticGate::Constant { val } => (None, f.constant(val, q, channel)?),
150                ArithmeticGate::Add { xref, yref, out } => (
151                    out,
152                    f.add(
153                        cache[xref.ix].as_ref().unwrap(),
154                        cache[yref.ix].as_ref().unwrap(),
155                    ),
156                ),
157                ArithmeticGate::Sub { xref, yref, out } => (
158                    out,
159                    f.sub(
160                        cache[xref.ix].as_ref().unwrap(),
161                        cache[yref.ix].as_ref().unwrap(),
162                    ),
163                ),
164                ArithmeticGate::Cmul { xref, c, out } => {
165                    (out, f.cmul(cache[xref.ix].as_ref().unwrap(), c))
166                }
167                ArithmeticGate::Proj {
168                    xref, ref tt, out, ..
169                } => (
170                    out,
171                    f.proj(
172                        cache[xref.ix].as_ref().unwrap(),
173                        q,
174                        Some(tt.to_vec()),
175                        channel,
176                    )?,
177                ),
178                ArithmeticGate::Mul {
179                    xref, yref, out, ..
180                } => (
181                    out,
182                    f.mul(
183                        cache[xref.ix].as_ref().unwrap(),
184                        cache[yref.ix].as_ref().unwrap(),
185                        channel,
186                    )?,
187                ),
188            };
189            cache[zref_.unwrap_or(i)] = Some(val);
190        }
191        let mut outputs = Vec::with_capacity(self.noutputs());
192        for r in self.get_output_refs().iter() {
193            let wirelabel = cache[r.ix].as_ref().unwrap();
194            outputs.push(wirelabel.clone());
195        }
196        Ok(outputs)
197    }
198}
199
200impl CircuitType for ArithmeticCircuit {
201    type Gate = ArithmeticGate;
202
203    fn new(ngates: Option<usize>) -> ArithmeticCircuit {
204        let gates = Vec::with_capacity(ngates.unwrap_or(0));
205        ArithmeticCircuit {
206            gates,
207            garbler_input_refs: Vec::new(),
208            evaluator_input_refs: Vec::new(),
209            const_refs: Vec::new(),
210            output_refs: Vec::new(),
211            gate_moduli: Vec::new(),
212            num_nonfree_gates: 0,
213        }
214    }
215
216    fn push_gates(&mut self, gate: Self::Gate) {
217        self.gates.push(gate)
218    }
219
220    fn push_const_ref(&mut self, xref: CircuitRef) {
221        self.const_refs.push(xref)
222    }
223
224    fn push_output_ref(&mut self, xref: CircuitRef) {
225        self.output_refs.push(xref)
226    }
227
228    fn push_garbler_input_ref(&mut self, xref: CircuitRef) {
229        self.garbler_input_refs.push(xref)
230    }
231
232    fn push_modulus(&mut self, modulus: u16) {
233        self.gate_moduli.push(modulus)
234    }
235
236    fn push_evaluator_input_ref(&mut self, xref: CircuitRef) {
237        self.evaluator_input_refs.push(xref)
238    }
239
240    fn increment_nonfree_gates(&mut self) {
241        self.num_nonfree_gates += 1;
242    }
243
244    fn get_num_nonfree_gates(&self) -> usize {
245        self.num_nonfree_gates
246    }
247
248    fn get_output_refs(&self) -> &[CircuitRef] {
249        &self.output_refs
250    }
251
252    fn get_garbler_input_refs(&self) -> &[CircuitRef] {
253        &self.garbler_input_refs
254    }
255
256    fn get_evaluator_input_refs(&self) -> &[CircuitRef] {
257        &self.evaluator_input_refs
258    }
259
260    fn garbler_input_mod(&self, i: usize) -> u16 {
261        let r = self.garbler_input_refs[i];
262        r.modulus()
263    }
264
265    fn evaluator_input_mod(&self, i: usize) -> u16 {
266        let r = self.evaluator_input_refs[i];
267        r.modulus()
268    }
269}
270
271impl ArithmeticCircuit {
272    /// Return the modulus of the gate indexed by `i`.
273    #[inline]
274    pub fn modulus(&self, i: usize) -> u16 {
275        self.gate_moduli[i]
276    }
277}
278
/// Boolean operations for arithmetic circuits, built from the arithmetic gates
/// (`add`, `mul`) over binary wires.
impl FancyBinary for CircuitBuilder<ArithmeticCircuit> {
    /// XOR of two binary wires, recorded as an arithmetic `add`.
    ///
    /// `check_binary!` is applied to both operands first (presumably asserting
    /// modulus 2 — confirm against the macro definition).
    fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
        check_binary!(x);
        check_binary!(y);

        // On mod-2 wires, addition is XOR.
        self.add(x, y)
    }

    /// AND of two binary wires, recorded as an arithmetic `mul`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `mul`.
    fn and(
        &mut self,
        x: &Self::Item,
        y: &Self::Item,
        channel: &mut Channel,
    ) -> swanky_error::Result<Self::Item> {
        check_binary!(x);
        check_binary!(y);

        // On mod-2 wires, multiplication is AND.
        self.mul(x, y, channel)
    }

    /// NOT of a binary wire, computed as `x XOR 1`.
    fn negate(&mut self, x: &Self::Item) -> Self::Item {
        check_binary!(x);

        // The constant 1 with modulus 2; XOR with it flips the bit.
        let one = self.lookup_constant(1, 2);

        self.xor(x, &one)
    }
}
307
308impl FancyArithmetic for CircuitBuilder<ArithmeticCircuit> {
309    fn add(&mut self, xref: &CircuitRef, yref: &CircuitRef) -> CircuitRef {
310        assert_eq!(xref.modulus(), yref.modulus());
311        let gate = ArithmeticGate::Add {
312            xref: *xref,
313            yref: *yref,
314            out: None,
315        };
316        self.gate(gate, xref.modulus())
317    }
318
319    fn sub(&mut self, xref: &CircuitRef, yref: &CircuitRef) -> CircuitRef {
320        assert_eq!(xref.modulus(), yref.modulus());
321        let gate = ArithmeticGate::Sub {
322            xref: *xref,
323            yref: *yref,
324            out: None,
325        };
326        self.gate(gate, xref.modulus())
327    }
328
329    fn cmul(&mut self, xref: &CircuitRef, c: u16) -> CircuitRef {
330        self.gate(
331            ArithmeticGate::Cmul {
332                xref: *xref,
333                c,
334                out: None,
335            },
336            xref.modulus(),
337        )
338    }
339
340    fn proj(
341        &mut self,
342        xref: &CircuitRef,
343        output_modulus: u16,
344        tt: Option<Vec<u16>>,
345        _: &mut Channel,
346    ) -> swanky_error::Result<CircuitRef> {
347        assert!(tt.is_some(), "`tt` must not be `None`");
348        let tt = tt.unwrap();
349        assert!(
350            tt.len() >= xref.modulus() as usize,
351            "`tt` not large enough for `x`s modulus"
352        );
353        assert!(
354            tt.iter().all(|&x| x < output_modulus),
355            "`tt` value larger than `q`"
356        );
357        let gate = ArithmeticGate::Proj {
358            xref: *xref,
359            tt: tt.to_vec(),
360            id: self.get_next_ciphertext_id(),
361            out: None,
362        };
363        Ok(self.gate(gate, output_modulus))
364    }
365
366    fn mul(
367        &mut self,
368        xref: &CircuitRef,
369        yref: &CircuitRef,
370        _channel: &mut Channel,
371    ) -> swanky_error::Result<CircuitRef> {
372        if xref.modulus() < yref.modulus() {
373            return self.mul(yref, xref, _channel);
374        }
375
376        let gate = ArithmeticGate::Mul {
377            xref: *xref,
378            yref: *yref,
379            id: self.get_next_ciphertext_id(),
380            out: None,
381        };
382
383        Ok(self.gate(gate, xref.modulus()))
384    }
385}