fancy_garbling/
wire.rs

1//! Wirelabels for use in garbled circuits.
2//!
3//! This module contains a [`WireLabel`] trait, alongside various instantiations
4//! of this trait. The [`WireLabel`] trait is the core underlying primitive used
5//! in garbled circuits, and represents an encoding of the value on any given
6//! wire of the circuit.
7
8use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_aes_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12    U8x16,
13    array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24/// Hash a batch of wires, using the same tweak for each wire.
25pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27    ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29    let batch = wires.array_map(|x| x.to_repr());
30    TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker for [`WireLabel`] instantiations on which arithmetic operations are
/// supported.
///
/// The trait has no methods; its only requirement is that implementors are
/// [`Clone`].
pub trait ArithmeticWire
where
    Self: Clone,
{
}
36
37/// A trait that defines a wirelabel as used in garbled circuits.
38///
39/// At its core, a [`WireLabel`] is a way of encoding values, and operating on
40/// those encoded values.
41pub trait WireLabel:
42    Clone
43    + HasModulus
44    + core::ops::Add<Output = Self>
45    + core::ops::AddAssign
46    + core::ops::Sub<Output = Self>
47    + core::ops::SubAssign
48    + core::ops::Neg<Output = Self>
49    + core::ops::Mul<u16, Output = Self>
50    + core::ops::MulAssign<u16>
51{
52    /// The underlying digits encoded by the [`WireLabel`].
53    fn digits(&self) -> Vec<u16>;
54
55    /// Converts a [`WireLabel`] into its [`U8x16`] representation.
56    fn to_repr(&self) -> U8x16;
57
58    /// The color digit of the wire.
59    fn color(&self) -> u16;
60
61    /// Converts a [`U8x16`] into its [`WireLabel`] representation, based on the
62    /// modulus `q`.
63    ///
64    /// # Panics
65    /// This panics if `q` does not align with the modulus supported by the
66    /// [`WireLabel`].
67    fn from_repr(inp: U8x16, q: u16) -> Self;
68
69    /// The zero [`WireLabel`], based on the modulus `q`.
70    ///
71    /// # Panics
72    /// This panics if `q` does not align with the modulus supported by the
73    /// [`WireLabel`].
74    // TODO: This is deceiving. It is _not_ a zero wirelabel as it is called in
75    // the literature, but rather simply a zero _value_. This could lead to bugs
76    // and should be changed!
77    fn zero(q: u16) -> Self;
78
79    /// A random [`WireLabel`] `mod q`, with the first digit set to `1`.
80    ///
81    /// # Panics
82    /// This panics if `q` does not align with the modulus supported by the
83    /// [`WireLabel`].
84    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
85
86    /// A random [`WireLabel`] `mod q`.
87    ///
88    /// # Panics
89    /// This panics if `q` does not align with the modulus supported by the
90    /// [`WireLabel`].
91    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
92
93    /// Converts a hashed block into a valid wire of the given modulus `q`.
94    ///
95    /// This is useful when separately using [`hash_wires`] to hash a set of
96    /// wires in one shot for efficiency reasons.
97    ///
98    /// # Panics
99    /// This panics if `q` does not align with the modulus supported by the
100    /// [`WireLabel`].
101    fn hash_to_mod(hash: U8x16, q: u16) -> Self;
102
103    /// Computes the hash of this [`WireLabel`], converting the result back into
104    /// a [`WireLabel`] based on the modulus `q`.
105    ///
106    /// This is equivalent to `WireLabel::hash_to_mod(self.hash(tweak), q)`, and
107    /// is useful when stringing together a sequence of operations on a
108    /// [`WireLabel`].
109    ///
110    /// # Panics
111    /// This panics if `q` does not align with the modulus supported by the
112    /// [`WireLabel`].
113    fn hashback(&self, tweak: u128, q: u16) -> Self {
114        let hash = self.hash(tweak);
115        Self::hash_to_mod(hash, q)
116    }
117
118    /// Computes the hash of the [`WireLabel`].
119    #[inline(never)]
120    fn hash(&self, tweak: u128) -> U8x16 {
121        TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
122    }
123}
124
125#[derive(Debug, Clone, PartialEq)]
126#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
127/// A [`WireLabel`] that supports all possible moduli
128pub enum AllWire {
129    /// A `mod 2` [`WireLabel`].
130    Mod2(WireMod2),
131    /// A `mod 3` [`WireLabel`].
132    Mod3(WireMod3),
133    /// A `mod q` [`WireLabel`], where `3 < q < 2^16`.
134    ModN(WireModQ),
135}
136
137impl HasModulus for AllWire {
138    fn modulus(&self) -> u16 {
139        match &self {
140            AllWire::Mod2(x) => x.modulus(),
141            AllWire::Mod3(x) => x.modulus(),
142            AllWire::ModN(x) => x.modulus(),
143        }
144    }
145}
146
147impl core::ops::Add for AllWire {
148    type Output = Self;
149
150    fn add(self, rhs: Self) -> Self::Output {
151        let (p, q) = (self.modulus(), rhs.modulus());
152        match (self, rhs) {
153            (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
154            (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
155            (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
156            _ => panic!("unequal moduli: {p} != {q}"),
157        }
158    }
159}
160
161impl core::ops::AddAssign for AllWire {
162    fn add_assign(&mut self, rhs: Self) {
163        let (p, q) = (self.modulus(), rhs.modulus());
164        match (self, rhs) {
165            (Self::Mod2(x), Self::Mod2(y)) => *x += y,
166            (Self::Mod3(x), Self::Mod3(y)) => *x += y,
167            (Self::ModN(x), Self::ModN(y)) => *x += y,
168            _ => panic!("unequal moduli: {p} != {q}"),
169        }
170    }
171}
172
173impl core::ops::Sub for AllWire {
174    type Output = Self;
175
176    fn sub(self, rhs: Self) -> Self::Output {
177        self + -rhs
178    }
179}
180
181impl core::ops::SubAssign for AllWire {
182    fn sub_assign(&mut self, rhs: Self) {
183        *self = self.clone() - rhs;
184    }
185}
186
187impl core::ops::Neg for AllWire {
188    type Output = Self;
189
190    fn neg(self) -> Self::Output {
191        match self {
192            Self::Mod2(x) => Self::Mod2(-x),
193            Self::Mod3(x) => Self::Mod3(-x),
194            Self::ModN(x) => Self::ModN(-x),
195        }
196    }
197}
198
199impl core::ops::Mul<u16> for AllWire {
200    type Output = Self;
201
202    fn mul(self, rhs: u16) -> Self::Output {
203        match self {
204            Self::Mod2(x) => Self::Mod2(x * rhs),
205            Self::Mod3(x) => Self::Mod3(x * rhs),
206            Self::ModN(x) => Self::ModN(x * rhs),
207        }
208    }
209}
210
211impl core::ops::MulAssign<u16> for AllWire {
212    fn mul_assign(&mut self, rhs: u16) {
213        match self {
214            Self::Mod2(x) => {
215                *x *= rhs;
216            }
217            Self::Mod3(x) => {
218                *x *= rhs;
219            }
220            Self::ModN(x) => {
221                *x *= rhs;
222            }
223        };
224    }
225}
226
227impl WireLabel for AllWire {
228    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
229        match q {
230            2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
231            3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
232            _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
233        }
234    }
235
236    fn digits(&self) -> Vec<u16> {
237        match &self {
238            AllWire::Mod2(x) => x.digits(),
239            AllWire::Mod3(x) => x.digits(),
240            AllWire::ModN(x) => x.digits(),
241        }
242    }
243
244    fn to_repr(&self) -> U8x16 {
245        match &self {
246            AllWire::Mod2(x) => x.to_repr(),
247            AllWire::Mod3(x) => x.to_repr(),
248            AllWire::ModN(x) => x.to_repr(),
249        }
250    }
251    fn color(&self) -> u16 {
252        match &self {
253            AllWire::Mod2(x) => x.color(),
254            AllWire::Mod3(x) => x.color(),
255            AllWire::ModN(x) => x.color(),
256        }
257    }
258    fn from_repr(inp: U8x16, q: u16) -> Self {
259        match q {
260            2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
261            3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
262            _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
263        }
264    }
265
266    fn zero(q: u16) -> Self {
267        match q {
268            2 => AllWire::Mod2(WireMod2::zero(q)),
269            3 => AllWire::Mod3(WireMod3::zero(q)),
270            _ => AllWire::ModN(WireModQ::zero(q)),
271        }
272    }
273
274    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
275        match q {
276            2 => AllWire::Mod2(WireMod2::rand(rng, q)),
277            3 => AllWire::Mod3(WireMod3::rand(rng, q)),
278            _ => AllWire::ModN(WireModQ::rand(rng, q)),
279        }
280    }
281
282    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
283        if q == 3 {
284            AllWire::Mod3(WireMod3::encode_block_mod3(hash))
285        } else {
286            Self::from_repr(hash, q)
287        }
288    }
289}
290fn _unrank(inp: u128, q: u16) -> Vec<u16> {
291    let mut x = inp;
292    let ndigits = util::digits_per_u128(q);
293    let npaths_tab = npaths_tab::lookup(q);
294    x %= npaths_tab[ndigits - 1] * q as u128;
295
296    let mut ds = vec![0; ndigits];
297    for i in (0..ndigits).rev() {
298        let npaths = npaths_tab[i];
299
300        if q <= 23 {
301            // linear search
302            let mut acc = 0;
303            for j in 0..q {
304                acc += npaths;
305                if acc > x {
306                    x -= acc - npaths;
307                    ds[i] = j;
308                    break;
309                }
310            }
311        } else {
312            // naive division
313            let d = x / npaths;
314            ds[i] = d as u16;
315            x -= d * npaths;
316        }
317        // } else {
318        //     // binary search
319        //     let mut low = 0;
320        //     let mut high = q;
321        //     loop {
322        //         let cur = (low + high) / 2;
323        //         let l = npaths * cur as u128;
324        //         let r = npaths * (cur as u128 + 1);
325        //         if x >= l && x < r {
326        //             x -= l;
327        //             ds[i] = cur;
328        //             break;
329        //         }
330        //         if x < l {
331        //             high = cur;
332        //         } else {
333        //             // x >= r
334        //             low = cur;
335        //         }
336        //     }
337        // }
338    }
339    ds
340}
341
342impl ArithmeticWire for AllWire {}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    #[test]
    fn packing() {
        let rng = &mut thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let w = AllWire::rand(rng, q);
                assert_eq!(w, AllWire::from_repr(w.to_repr(), q));
            }
        }
    }

    #[test]
    fn base_conversion_lookup_method() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let w = AllWire::from_repr(U8x16::from(x), q);
            let should_be = util::as_base_q_u128(x, q);
            assert_eq!(w.digits(), should_be, "x={x} q={q}");
        }
    }

    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            // `assert_ne!` prints both wires on failure, unlike `assert!(x != y)`.
            assert_ne!(x, y, "q={q}");
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    #[test]
    fn negation() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(rng, q);
            let xneg = -x.clone();
            // Mod 2 every value is its own negation, so only check q != 2.
            if q != 2 {
                assert_ne!(x, xneg, "q={q}");
            }
            let y = -xneg;
            assert_eq!(x, y);
        }
    }

    #[test]
    fn zero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 3 + (rng.gen_u16() % 110);
            let z = AllWire::zero(q);
            let ds = z.digits();
            assert_eq!(ds, vec![0; ds.len()], "q={q}");
        }
    }

    #[test]
    fn subzero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let z = AllWire::zero(q);
            assert_eq!(x.clone() - x, z);
        }
    }

    #[test]
    fn pluszero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() + AllWire::zero(q), x);
        }
    }

    #[test]
    // `x * 0` below is deliberate: the test checks that scaling by zero
    // yields the zero wire, which clippy's `erasing_op` would otherwise flag.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() * 0, AllWire::zero(q));
            assert_eq!(x.clone() * q, AllWire::zero(q));
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), AllWire::zero(q), "q={q}");
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }
            let mut w = x.clone();
            let z = w.clone() + y.clone();
            w += y;
            assert_eq!(w, z);

            w = x.clone();
            w *= 2;
            assert_eq!(x.clone() + x.clone(), w);

            w = x.clone();
            w = -w;
            assert_eq!(-x, w);
        }
    }

    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.digits().len(), util::digits_per_u128(q));
        }
    }

    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let ws = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let mut handles = Vec::new();
        for w in ws.iter() {
            let w_ = w.clone();
            let h = std::thread::spawn(move || w_.hash(0u128));
            handles.push(h);
        }
        let hashes = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();

        let should_be = ws.iter().map(|w| w.hash(0u128)).collect_vec();

        assert_eq!(hashes, should_be);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let w = AllWire::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: AllWire = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }
}