Skip to main content

fancy_garbling/
circuit.rs

1use crate::{BinaryBundle, Bundle, CrtBundle, HasModulus, fancy::Fancy};
2use itertools::Itertools;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
6mod binary;
7pub use binary::{BinaryCircuit, BinaryGate};
8
/// Conversion of a [`Circuit`] output into a flat list of wires.
///
/// Implementors describe how a structured output (a single wire, a tuple,
/// a bundle, ...) is laid out as a single `Vec` of wires, in order.
pub trait Flatten {
    /// The wire type produced by flattening.
    type Item;

    /// Consume `self` and return its wires as one vector.
    fn flatten(self) -> Vec<Self::Item>;
}
17
18impl<T: Clone + HasModulus> Flatten for Vec<T> {
19    type Item = T;
20
21    fn flatten(self) -> Vec<Self::Item> {
22        self
23    }
24}
25
26impl<T: Clone + HasModulus> Flatten for T {
27    type Item = T;
28
29    fn flatten(self) -> Vec<Self::Item> {
30        vec![self]
31    }
32}
33
34impl<T: Clone + HasModulus> Flatten for (T, T) {
35    type Item = T;
36
37    fn flatten(self) -> Vec<Self::Item> {
38        vec![self.0]
39    }
40}
41
42impl<T: Clone + HasModulus> Flatten for Bundle<T> {
43    type Item = T;
44
45    fn flatten(self) -> Vec<Self::Item> {
46        self.wires().to_vec()
47    }
48}
49
50impl<T: Clone + HasModulus> Flatten for BinaryBundle<T> {
51    type Item = T;
52
53    fn flatten(self) -> Vec<T> {
54        self.wires().to_vec()
55    }
56}
57
58impl<T: Clone + HasModulus> Flatten for CrtBundle<T> {
59    type Item = T;
60
61    fn flatten(self) -> Vec<T> {
62        self.extract().wires().to_vec()
63    }
64}
65
66impl<T: Clone + HasModulus> Flatten for Vec<CrtBundle<T>> {
67    type Item = T;
68
69    fn flatten(self) -> Vec<Self::Item> {
70        self.into_iter().map(|bundle| bundle.flatten()).concat()
71    }
72}
73
74impl<T: Clone + HasModulus> Flatten for (T, BinaryBundle<T>) {
75    type Item = T;
76
77    fn flatten(self) -> Vec<Self::Item> {
78        [vec![self.0], self.1.flatten()].concat()
79    }
80}
81
82impl<T: Clone + HasModulus> Flatten for (BinaryBundle<T>, T) {
83    type Item = T;
84
85    fn flatten(self) -> Vec<Self::Item> {
86        [self.0.flatten(), vec![self.1]].concat()
87    }
88}
89
90impl<T: Clone + HasModulus, const N: usize> Flatten for [T; N] {
91    type Item = T;
92
93    fn flatten(self) -> Vec<Self::Item> {
94        self.to_vec()
95    }
96}
97
98/// Trait for defining computations over [`Fancy`] objects.
99///
100/// A `Circuit` computation is defined by a [`Circuit::Input`] associated type,
101/// a [`Circuit::Output`] associated type, and a [`Circuit::execute`] method
102/// that maps [`Circuit::Input`] to [`Circuit::Output`]. The body of
103/// [`Circuit::execute`] may use other `Circuit`s internally.
104///
105/// For mapping arbitrary inputs into the correct `Circuit` input
106/// representation, use the [`CircuitInputMapper`] trait.
107///
108/// # Example
109/// Below is a simple circuit computing an add gate. The computation is defined
110/// in `execute` by directly calling operations on the underlying [`Fancy`]
111/// backend ([`crate::FancyArithmetic`] in this example).
112/// ```
113/// # use fancy_garbling::{FancyArithmetic, Circuit};
114/// # use swanky_channel::Channel;
115/// # use swanky_error::Result;
116/// struct AddCircuit;
117/// impl<F: FancyArithmetic> Circuit<F> for AddCircuit {
118///     type Input = (F::Item, F::Item);
119///     type Output = F::Item;
120///
121///     fn execute(
122///         &self,
123///         backend: &mut F,
124///         inputs: Self::Input,
125///         channel: &mut Channel,
126///     ) -> Result<Self::Output> {
127///         Ok(backend.add(&inputs.0, &inputs.1))
128///     }
129/// }
130/// ```
131/// Given `AddCircuit`, any object instantiating the required [`Fancy`] traits
132/// can evaluate the circuit by calling `AddMany.execute(...)`.
133pub trait Circuit<F: Fancy> {
134    /// The input type of the circuit.
135    type Input;
136    /// The output type of the circuit.
137    ///
138    /// The [`Flatten`] trait allows the output type to be converted into a
139    /// `Vec<F::Item>`.
140    type Output: Flatten<Item = F::Item>;
141
142    /// Execute a circuit on a given [`Fancy`] backend using the provided inputs.
143    fn execute(
144        &self,
145        backend: &mut F,
146        inputs: Self::Input,
147        channel: &mut Channel,
148    ) -> Result<Self::Output>;
149}
150
151/// Trait for defining input-size-dependent [`Circuit`]s.
152///
153/// The [`Circuit`] trait allows one to write circuits for arbitrary-length
154/// inputs, which becomes a problem when needing to be used in, for example, a
155/// garbled circuit protocol, where the input length needs to be know during the
156/// Oblivious Transfer phase of the protocol. This is where `CircuitInputMapper`
157/// comes in: it provides a [`CircuitInputMapper::map`] method for mapping
158/// vectors of inputs to the appropriate input type as required to run
159/// [`Circuit::execute`], alongside [`CircuitInputMapper::ninputs`] for
160/// determining the number of inputs for the circuit. Finally,
161/// [`CircuitInputMapper::modulus`] outputs the particular modulus required for
162/// the `i`th input wire.
163///
164/// While certain [`Fancy`] instantiations can evaluate [`Circuit`]s directly,
165/// several, including garbled circuits and zero knowledge protocols, operate
166/// over `CircuitInputMapper`s instead, and any [`Circuit`] to be run under
167/// these protocols needs to implement `CircuitInputMapper` as well.
168///
169/// # Example
170/// The below code extends the `AddCircuit` example from the [`Circuit`]
171/// documentation to support mapping a vector of inputs into the appropriate
172/// input type for the given circuit.
173/// ```
174/// # use fancy_garbling::{FancyArithmetic, Circuit, CircuitInputMapper};
175/// # use swanky_channel::Channel;
176/// # use swanky_error::Result;
177/// # struct AddCircuit;
178/// # impl<F: FancyArithmetic> Circuit<F> for AddCircuit {
179/// #     type Input = (F::Item, F::Item);
180/// #     type Output = F::Item;
181/// #
182/// #     fn execute(
183/// #         &self,
184/// #         backend: &mut F,
185/// #         inputs: Self::Input,
186/// #         channel: &mut Channel,
187/// #     ) -> Result<Self::Output> {
188/// #         Ok(backend.add(&inputs.0, &inputs.1))
189/// #     }
190/// # }
191/// impl<F: FancyArithmetic> CircuitInputMapper<F> for AddCircuit {
192///     fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
193///         assert_eq!(inputs.len(), 2);
194///         (inputs[0].clone(), inputs[1].clone())
195///     }
196///
197///     fn ninputs(&self) -> usize {
198///         2
199///     }
200///
201///     fn modulus(&self, _: usize) -> u16 {
202///         2
203///     }
204/// }
205/// ```
206pub trait CircuitInputMapper<F: Fancy>: Circuit<F> {
207    /// Map a vector of inputs to [`Circuit::Input`].
208    ///
209    /// # Panics
210    /// This panics if the number of inputs does not match the expected input
211    /// size.
212    fn map(&self, inputs: Vec<F::Item>) -> Self::Input;
213    /// The number of inputs to provide to [`CircuitInputMapper::map`].
214    fn ninputs(&self) -> usize;
215    /// The modulus of the `i`th input.
216    fn modulus(&self, i: usize) -> u16;
217}
218
219pub mod test_circuits {
220    //! A collection of test circuits.
221
222    pub mod fancy {
223        //! Circuits that test [`Fancy`].
224
225        use crate::{
226            Fancy,
227            circuit::{Circuit, CircuitInputMapper},
228        };
229        use swanky_channel::Channel;
230        use swanky_error::Result;
231
232        /// Circuit for testing [`Fancy::constant`] on binary values.
233        pub struct TestBinaryConstant;
234        impl<F: Fancy> Circuit<F> for TestBinaryConstant {
235            type Input = ();
236            type Output = Vec<F::Item>;
237
238            fn execute(
239                &self,
240                backend: &mut F,
241                _: Self::Input,
242                channel: &mut Channel,
243            ) -> Result<Self::Output> {
244                let outputs = vec![
245                    backend.constant(0, 2, channel)?,
246                    backend.constant(1, 2, channel)?,
247                ];
248                Ok(outputs)
249            }
250        }
251        impl<F: Fancy> CircuitInputMapper<F> for TestBinaryConstant {
252            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
253                assert!(inputs.is_empty());
254            }
255
256            fn ninputs(&self) -> usize {
257                0
258            }
259
260            fn modulus(&self, _: usize) -> u16 {
261                2
262            }
263        }
264    }
265
266    pub mod binary {
267        //! Circuits that test [`FancyBinary`].
268
269        use crate::{
270            FancyBinary,
271            circuit::{Circuit, CircuitInputMapper},
272            circuits::binary::{AndMany, OrMany, XorMany},
273        };
274        use swanky_channel::Channel;
275        use swanky_error::Result;
276
277        /// Circuit for testing [`FancyBinary::negate`].
278        pub struct TestNegateGate;
279        impl<F: FancyBinary> Circuit<F> for TestNegateGate {
280            type Input = F::Item;
281            type Output = F::Item;
282
283            fn execute(
284                &self,
285                backend: &mut F,
286                input: Self::Input,
287                _: &mut Channel,
288            ) -> Result<Self::Output> {
289                Ok(backend.negate(&input))
290            }
291        }
292        impl<F: FancyBinary> CircuitInputMapper<F> for TestNegateGate {
293            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
294                assert_eq!(inputs.len(), 1);
295                inputs[0].clone()
296            }
297
298            fn ninputs(&self) -> usize {
299                1
300            }
301
302            fn modulus(&self, _: usize) -> u16 {
303                2
304            }
305        }
306
307        /// Circuit for testing [`FancyBinary::and`].
308        pub struct TestAndGate;
309        impl<F: FancyBinary> Circuit<F> for TestAndGate {
310            type Input = (F::Item, F::Item);
311            type Output = F::Item;
312
313            fn execute(
314                &self,
315                backend: &mut F,
316                inputs: Self::Input,
317                channel: &mut Channel,
318            ) -> Result<Self::Output> {
319                backend.and(&inputs.0, &inputs.1, channel)
320            }
321        }
322        impl<F: FancyBinary> CircuitInputMapper<F> for TestAndGate {
323            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
324                assert_eq!(inputs.len(), 2);
325                (inputs[0].clone(), inputs[1].clone())
326            }
327
328            fn ninputs(&self) -> usize {
329                2
330            }
331
332            fn modulus(&self, _: usize) -> u16 {
333                2
334            }
335        }
336
337        /// Circuit for testing [`AndMany`].
338        pub struct TestAndGateFanN(pub usize);
339        impl<F: FancyBinary> Circuit<F> for TestAndGateFanN {
340            type Input = Vec<F::Item>;
341            type Output = F::Item;
342
343            fn execute(
344                &self,
345                backend: &mut F,
346                inputs: Self::Input,
347                channel: &mut Channel,
348            ) -> Result<Self::Output> {
349                AndMany::new().execute(backend, inputs.as_slice(), channel)
350            }
351        }
352
353        impl<F: FancyBinary> CircuitInputMapper<F> for TestAndGateFanN {
354            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
355                assert_eq!(inputs.len(), self.0);
356                inputs
357            }
358
359            fn ninputs(&self) -> usize {
360                self.0
361            }
362
363            fn modulus(&self, _: usize) -> u16 {
364                2
365            }
366        }
367
368        /// Circuit for testing [`OrMany`].
369        pub struct TestOrGateFanN(pub usize);
370        impl<F: FancyBinary> Circuit<F> for TestOrGateFanN {
371            type Input = Vec<F::Item>;
372            type Output = F::Item;
373
374            fn execute(
375                &self,
376                backend: &mut F,
377                inputs: Self::Input,
378                channel: &mut Channel,
379            ) -> Result<Self::Output> {
380                OrMany::new().execute(backend, inputs.as_slice(), channel)
381            }
382        }
383        impl<F: FancyBinary> CircuitInputMapper<F> for TestOrGateFanN {
384            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
385                assert_eq!(inputs.len(), self.0);
386                inputs
387            }
388
389            fn ninputs(&self) -> usize {
390                self.0
391            }
392
393            fn modulus(&self, _: usize) -> u16 {
394                2
395            }
396        }
397
398        /// Circuit for testing [`XorMany`].
399        pub struct TestXorGateFanN(pub usize);
400        impl<F: FancyBinary> Circuit<F> for TestXorGateFanN {
401            type Input = Vec<F::Item>;
402            type Output = F::Item;
403
404            fn execute(
405                &self,
406                backend: &mut F,
407                inputs: Self::Input,
408                channel: &mut Channel,
409            ) -> Result<Self::Output> {
410                XorMany::new().execute(backend, inputs.as_slice(), channel)
411            }
412        }
413        impl<F: FancyBinary> CircuitInputMapper<F> for TestXorGateFanN {
414            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
415                assert_eq!(inputs.len(), self.0);
416                inputs
417            }
418
419            fn ninputs(&self) -> usize {
420                self.0
421            }
422
423            fn modulus(&self, _: usize) -> u16 {
424                2
425            }
426        }
427    }
428
429    pub mod arithmetic {
430        //! Circuits that test [`FancyArithmetic`].
431
432        use crate::{
433            FancyArithmetic,
434            circuit::{Circuit, CircuitInputMapper},
435            circuits::arithmetic::AddMany,
436        };
437        use swanky_channel::Channel;
438        use swanky_error::Result;
439
440        /// Circuit for testing [`FancyArithmetic::add`].
441        pub struct TestAddition(pub u16);
442        impl<F: FancyArithmetic> Circuit<F> for TestAddition {
443            type Input = (F::Item, F::Item);
444            type Output = F::Item;
445
446            fn execute(
447                &self,
448                backend: &mut F,
449                inputs: Self::Input,
450                _: &mut Channel,
451            ) -> Result<Self::Output> {
452                Ok(backend.add(&inputs.0, &inputs.1))
453            }
454        }
455        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestAddition {
456            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
457                assert_eq!(inputs.len(), 2);
458                (inputs[0].clone(), inputs[1].clone())
459            }
460
461            fn ninputs(&self) -> usize {
462                2
463            }
464
465            fn modulus(&self, _: usize) -> u16 {
466                self.0
467            }
468        }
469
470        /// Circuit for testing [`AddMany`].
471        pub struct TestAddMany(pub u16, pub usize);
472        impl<F: FancyArithmetic> Circuit<F> for TestAddMany {
473            type Input = Vec<F::Item>;
474            type Output = F::Item;
475
476            fn execute(
477                &self,
478                backend: &mut F,
479                inputs: Self::Input,
480                channel: &mut Channel,
481            ) -> Result<Self::Output> {
482                AddMany::new().execute(backend, inputs.as_slice(), channel)
483            }
484        }
485        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestAddMany {
486            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
487                assert_eq!(inputs.len(), self.1);
488                inputs
489            }
490
491            fn ninputs(&self) -> usize {
492                self.1
493            }
494
495            fn modulus(&self, _: usize) -> u16 {
496                self.0
497            }
498        }
499
500        /// Circuit for testing [`FancyArithmetic::sub`].
501        pub struct TestSubtraction(pub u16);
502        impl<F: FancyArithmetic> Circuit<F> for TestSubtraction {
503            type Input = (F::Item, F::Item);
504            type Output = F::Item;
505
506            fn execute(
507                &self,
508                backend: &mut F,
509                inputs: Self::Input,
510                _: &mut Channel,
511            ) -> Result<Self::Output> {
512                Ok(backend.sub(&inputs.0, &inputs.1))
513            }
514        }
515        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestSubtraction {
516            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
517                assert_eq!(inputs.len(), 2);
518                (inputs[0].clone(), inputs[1].clone())
519            }
520
521            fn ninputs(&self) -> usize {
522                2
523            }
524
525            fn modulus(&self, _: usize) -> u16 {
526                self.0
527            }
528        }
529
530        /// Circuit for testing [`FancyArithmetic::mul`].
531        pub struct TestMulGate(pub u16);
532        impl<F: FancyArithmetic> Circuit<F> for TestMulGate {
533            type Input = (F::Item, F::Item);
534            type Output = F::Item;
535
536            fn execute(
537                &self,
538                backend: &mut F,
539                inputs: Self::Input,
540                channel: &mut Channel,
541            ) -> Result<Self::Output> {
542                backend.mul(&inputs.0, &inputs.1, channel)
543            }
544        }
545        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestMulGate {
546            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
547                assert_eq!(inputs.len(), 2);
548                (inputs[0].clone(), inputs[1].clone())
549            }
550
551            fn ninputs(&self) -> usize {
552                2
553            }
554
555            fn modulus(&self, _: usize) -> u16 {
556                self.0
557            }
558        }
559
560        /// Circuit for testing [`FancyArithmetic::mul`] using two different moduli
561        /// for the inputs.
562        pub struct TestMulGateUnequalMods(pub [u16; 2]);
563        impl<F: FancyArithmetic> Circuit<F> for TestMulGateUnequalMods {
564            type Input = (F::Item, F::Item);
565            type Output = F::Item;
566
567            fn execute(
568                &self,
569                backend: &mut F,
570                inputs: Self::Input,
571                channel: &mut Channel,
572            ) -> Result<Self::Output> {
573                backend.mul(&inputs.0, &inputs.1, channel)
574            }
575        }
576        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestMulGateUnequalMods {
577            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
578                assert_eq!(inputs.len(), 2);
579                (inputs[0].clone(), inputs[1].clone())
580            }
581
582            fn ninputs(&self) -> usize {
583                2
584            }
585
586            fn modulus(&self, i: usize) -> u16 {
587                self.0[i]
588            }
589        }
590
591        /// Circuit for testing [`FancyArithmetic::cmul`].
592        pub struct TestCmul(pub u16, pub u16);
593        impl<F: FancyArithmetic> Circuit<F> for TestCmul {
594            type Input = F::Item;
595            type Output = F::Item;
596
597            fn execute(
598                &self,
599                backend: &mut F,
600                input: Self::Input,
601                _: &mut Channel,
602            ) -> Result<Self::Output> {
603                Ok(backend.cmul(&input, self.1))
604            }
605        }
606        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestCmul {
607            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
608                assert_eq!(inputs.len(), 1);
609                inputs[0].clone()
610            }
611
612            fn ninputs(&self) -> usize {
613                1
614            }
615
616            fn modulus(&self, _: usize) -> u16 {
617                self.0
618            }
619        }
620
621        /// Circuit for testing constant gates.
622        pub struct TestConstants(pub u16, pub u16);
623        impl<F: FancyArithmetic> Circuit<F> for TestConstants {
624            type Input = F::Item;
625            type Output = F::Item;
626
627            fn execute(
628                &self,
629                backend: &mut F,
630                input: Self::Input,
631                channel: &mut Channel,
632            ) -> Result<Self::Output> {
633                let constant = backend.constant(self.1, self.0, channel)?;
634                Ok(backend.add(&input, &constant))
635            }
636        }
637        impl<F: FancyArithmetic> CircuitInputMapper<F> for TestConstants {
638            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
639                assert_eq!(inputs.len(), 1);
640                inputs[0].clone()
641            }
642
643            fn ninputs(&self) -> usize {
644                1
645            }
646
647            fn modulus(&self, _: usize) -> u16 {
648                self.0
649            }
650        }
651    }
652
653    pub mod proj {
654        //! Circuits that test [`FancyProj`].
655
656        use crate::{
657            FancyProj,
658            circuit::{Circuit, CircuitInputMapper},
659        };
660        use swanky_channel::Channel;
661        use swanky_error::Result;
662
663        /// Circuit for testing [`FancyProj::proj`].
664        pub struct TestProj(pub u16);
665        impl<F: FancyProj> Circuit<F> for TestProj {
666            type Input = F::Item;
667            type Output = F::Item;
668
669            fn execute(
670                &self,
671                backend: &mut F,
672                input: Self::Input,
673                channel: &mut Channel,
674            ) -> Result<Self::Output> {
675                let tab = (0..self.0).map(|i| (i + 1) % self.0).collect();
676                backend.proj(&input, self.0, Some(tab), channel)
677            }
678        }
679        impl<F: FancyProj> CircuitInputMapper<F> for TestProj {
680            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
681                assert_eq!(inputs.len(), 1);
682                inputs[0].clone()
683            }
684
685            fn ninputs(&self) -> usize {
686                1
687            }
688
689            fn modulus(&self, _: usize) -> u16 {
690                self.0
691            }
692        }
693
694        /// Circuit for testing [`FancyProj::proj`] using a custom truth table.
695        pub struct TestProjRand(pub u16, pub Vec<u16>);
696        impl<F: FancyProj> Circuit<F> for TestProjRand {
697            type Input = F::Item;
698            type Output = F::Item;
699
700            fn execute(
701                &self,
702                backend: &mut F,
703                input: Self::Input,
704                channel: &mut Channel,
705            ) -> Result<Self::Output> {
706                backend.proj(&input, self.0, Some(self.1.clone()), channel)
707            }
708        }
709        impl<F: FancyProj> CircuitInputMapper<F> for TestProjRand {
710            fn map(&self, inputs: Vec<F::Item>) -> Self::Input {
711                assert_eq!(inputs.len(), 1);
712                inputs[0].clone()
713            }
714
715            fn ninputs(&self) -> usize {
716                1
717            }
718
719            fn modulus(&self, _: usize) -> u16 {
720                self.0
721            }
722        }
723    }
724}
725
#[cfg(test)]
mod fancy_arithmetic {
    //! Tests evaluating the arithmetic test circuits on the [`Dummy`] backend.

    use crate::{
        dummy::{Dummy, DummyVal},
        test_circuits::arithmetic::{TestConstants, TestMulGate},
        util::RngExt,
    };
    use rand::thread_rng;

    /// `TestConstants(q, c)` must compute `(x + c) mod q` for random `x`.
    #[test]
    fn constants() {
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let c = rng.gen_u16() % q;
        let circ = TestConstants(q, c);

        for _ in 0..64 {
            let x = DummyVal::rand(q, &mut rng);
            let output = Dummy::eval(&circ, x).unwrap();
            // Compute the reference in u32 so `x + c` cannot overflow u16,
            // however large the generated modulus is.
            let expected = (u32::from(x.val()) + u32::from(c)) % u32::from(q);
            assert_eq!(u32::from(output.val()), expected);
        }
    }

    /// `TestMulGate(q)` must compute `(x * y) mod q` for random `x`, `y`.
    #[test]
    fn arithmetic_half_gate() {
        let mut rng = thread_rng();
        let q = rng.gen_prime();
        let c = TestMulGate(q);

        for _ in 0..16 {
            let x = DummyVal::rand(q, &mut rng);
            let y = DummyVal::rand(q, &mut rng);
            let output = Dummy::eval(&c, (x, y)).unwrap();
            // The product of two residues overflows u16 once q >= 256, so
            // the reference value is computed in u32.
            let expected = (u32::from(x.val()) * u32::from(y.val())) % u32::from(q);
            assert_eq!(u32::from(output.val()), expected);
        }
    }
}