fancy_garbling/
wire.rs

//! Wire labels for use in garbled circuits.
//!
//! This module contains the [`WireLabel`] trait alongside several
//! instantiations of it. The [`WireLabel`] trait is the core primitive
//! underlying garbled circuits: it represents an encoding of the value carried
//! on a given wire of the circuit.
7
8use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12    U8x16,
13    array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24/// Hash a batch of wires, using the same tweak for each wire.
25pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27    ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29    let batch = wires.array_map(|x| x.to_repr());
30    TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait for [`WireLabel`] instantiations on which arithmetic
/// operations are supported.
///
/// The trait declares no methods of its own; implementing it only signals the
/// capability.
pub trait ArithmeticWire: Clone {}
36
37/// A trait that defines a wirelabel as used in garbled circuits.
38///
39/// At its core, a [`WireLabel`] is a way of encoding values, and operating on
40/// those encoded values.
41pub trait WireLabel:
42    Clone
43    + HasModulus
44    + core::ops::Add<Output = Self>
45    + core::ops::AddAssign
46    + core::ops::Sub<Output = Self>
47    + core::ops::SubAssign
48    + core::ops::Neg<Output = Self>
49    + core::ops::Mul<u16, Output = Self>
50    + core::ops::MulAssign<u16>
51{
52    /// The underlying digits encoded by the [`WireLabel`].
53    fn digits(&self) -> Vec<u16>;
54
55    /// Converts a [`WireLabel`] into its [`U8x16`] representation.
56    fn to_repr(&self) -> U8x16;
57
58    /// The color digit of the wire.
59    fn color(&self) -> u16;
60
61    /// Converts a [`U8x16`] into its [`WireLabel`] representation, based on the
62    /// modulus `q`.
63    ///
64    /// # Panics
65    /// This panics if `q` does not align with the modulus supported by the
66    /// [`WireLabel`].
67    fn from_repr(inp: U8x16, q: u16) -> Self;
68
69    /// A random [`WireLabel`] `mod q`, with the first digit set to `1`.
70    ///
71    /// # Panics
72    /// This panics if `q` does not align with the modulus supported by the
73    /// [`WireLabel`].
74    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
75
76    /// A random [`WireLabel`] `mod q`.
77    ///
78    /// # Panics
79    /// This panics if `q` does not align with the modulus supported by the
80    /// [`WireLabel`].
81    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
82
83    /// Converts a hashed block into a valid wire of the given modulus `q`.
84    ///
85    /// This is useful when separately using [`hash_wires`] to hash a set of
86    /// wires in one shot for efficiency reasons.
87    ///
88    /// # Panics
89    /// This panics if `q` does not align with the modulus supported by the
90    /// [`WireLabel`].
91    fn hash_to_mod(hash: U8x16, q: u16) -> Self;
92
93    /// Computes the hash of this [`WireLabel`], converting the result back into
94    /// a [`WireLabel`] based on the modulus `q`.
95    ///
96    /// This is equivalent to `WireLabel::hash_to_mod(self.hash(tweak), q)`, and
97    /// is useful when stringing together a sequence of operations on a
98    /// [`WireLabel`].
99    ///
100    /// # Panics
101    /// This panics if `q` does not align with the modulus supported by the
102    /// [`WireLabel`].
103    fn hashback(&self, tweak: u128, q: u16) -> Self {
104        let hash = self.hash(tweak);
105        Self::hash_to_mod(hash, q)
106    }
107
108    /// Computes the hash of the [`WireLabel`].
109    fn hash(&self, tweak: u128) -> U8x16 {
110        TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
111    }
112
113    /// Computes a [`WireLabel`] for `x % q`, returning both the zero
114    /// [`WireLabel`] as well as the [`WireLabel`] for `x % q`.
115    fn constant<RNG: CryptoRng + RngCore>(
116        x: u16,
117        q: u16,
118        delta: &Self,
119        rng: &mut RNG,
120    ) -> (Self, Self) {
121        let zero = Self::rand(rng, q);
122        let wire = zero.clone() + delta.clone() * x;
123        (zero, wire)
124    }
125}
126
127#[derive(Debug, Clone, PartialEq)]
128#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
129/// A [`WireLabel`] that supports all possible moduli
130pub enum AllWire {
131    /// A `mod 2` [`WireLabel`].
132    Mod2(WireMod2),
133    /// A `mod 3` [`WireLabel`].
134    Mod3(WireMod3),
135    /// A `mod q` [`WireLabel`], where `3 < q < 2^16`.
136    ModN(WireModQ),
137}
138
139impl HasModulus for AllWire {
140    fn modulus(&self) -> u16 {
141        match &self {
142            AllWire::Mod2(x) => x.modulus(),
143            AllWire::Mod3(x) => x.modulus(),
144            AllWire::ModN(x) => x.modulus(),
145        }
146    }
147}
148
149impl core::ops::Add for AllWire {
150    type Output = Self;
151
152    fn add(self, rhs: Self) -> Self::Output {
153        let (p, q) = (self.modulus(), rhs.modulus());
154        match (self, rhs) {
155            (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
156            (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
157            (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
158            _ => panic!("unequal moduli: {p} != {q}"),
159        }
160    }
161}
162
163impl core::ops::AddAssign for AllWire {
164    fn add_assign(&mut self, rhs: Self) {
165        let (p, q) = (self.modulus(), rhs.modulus());
166        match (self, rhs) {
167            (Self::Mod2(x), Self::Mod2(y)) => *x += y,
168            (Self::Mod3(x), Self::Mod3(y)) => *x += y,
169            (Self::ModN(x), Self::ModN(y)) => *x += y,
170            _ => panic!("unequal moduli: {p} != {q}"),
171        }
172    }
173}
174
175impl core::ops::Sub for AllWire {
176    type Output = Self;
177
178    fn sub(self, rhs: Self) -> Self::Output {
179        self + -rhs
180    }
181}
182
183impl core::ops::SubAssign for AllWire {
184    fn sub_assign(&mut self, rhs: Self) {
185        *self = self.clone() - rhs;
186    }
187}
188
189impl core::ops::Neg for AllWire {
190    type Output = Self;
191
192    fn neg(self) -> Self::Output {
193        match self {
194            Self::Mod2(x) => Self::Mod2(-x),
195            Self::Mod3(x) => Self::Mod3(-x),
196            Self::ModN(x) => Self::ModN(-x),
197        }
198    }
199}
200
201impl core::ops::Mul<u16> for AllWire {
202    type Output = Self;
203
204    fn mul(self, rhs: u16) -> Self::Output {
205        match self {
206            Self::Mod2(x) => Self::Mod2(x * rhs),
207            Self::Mod3(x) => Self::Mod3(x * rhs),
208            Self::ModN(x) => Self::ModN(x * rhs),
209        }
210    }
211}
212
213impl core::ops::MulAssign<u16> for AllWire {
214    fn mul_assign(&mut self, rhs: u16) {
215        match self {
216            Self::Mod2(x) => {
217                *x *= rhs;
218            }
219            Self::Mod3(x) => {
220                *x *= rhs;
221            }
222            Self::ModN(x) => {
223                *x *= rhs;
224            }
225        };
226    }
227}
228
229impl WireLabel for AllWire {
230    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
231        match q {
232            2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
233            3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
234            _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
235        }
236    }
237
238    fn digits(&self) -> Vec<u16> {
239        match &self {
240            AllWire::Mod2(x) => x.digits(),
241            AllWire::Mod3(x) => x.digits(),
242            AllWire::ModN(x) => x.digits(),
243        }
244    }
245
246    fn to_repr(&self) -> U8x16 {
247        match &self {
248            AllWire::Mod2(x) => x.to_repr(),
249            AllWire::Mod3(x) => x.to_repr(),
250            AllWire::ModN(x) => x.to_repr(),
251        }
252    }
253    fn color(&self) -> u16 {
254        match &self {
255            AllWire::Mod2(x) => x.color(),
256            AllWire::Mod3(x) => x.color(),
257            AllWire::ModN(x) => x.color(),
258        }
259    }
260    fn from_repr(inp: U8x16, q: u16) -> Self {
261        match q {
262            2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
263            3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
264            _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
265        }
266    }
267
268    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
269        match q {
270            2 => AllWire::Mod2(WireMod2::rand(rng, q)),
271            3 => AllWire::Mod3(WireMod3::rand(rng, q)),
272            _ => AllWire::ModN(WireModQ::rand(rng, q)),
273        }
274    }
275
276    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
277        if q == 3 {
278            AllWire::Mod3(WireMod3::encode_block_mod3(hash))
279        } else {
280            Self::from_repr(hash, q)
281        }
282    }
283}
284fn _unrank(inp: u128, q: u16) -> Vec<u16> {
285    let mut x = inp;
286    let ndigits = util::digits_per_u128(q);
287    let npaths_tab = npaths_tab::lookup(q);
288    x %= npaths_tab[ndigits - 1] * q as u128;
289
290    let mut ds = vec![0; ndigits];
291    for i in (0..ndigits).rev() {
292        let npaths = npaths_tab[i];
293
294        if q <= 23 {
295            // linear search
296            let mut acc = 0;
297            for j in 0..q {
298                acc += npaths;
299                if acc > x {
300                    x -= acc - npaths;
301                    ds[i] = j;
302                    break;
303                }
304            }
305        } else {
306            // naive division
307            let d = x / npaths;
308            ds[i] = d as u16;
309            x -= d * npaths;
310        }
311        // } else {
312        //     // binary search
313        //     let mut low = 0;
314        //     let mut high = q;
315        //     loop {
316        //         let cur = (low + high) / 2;
317        //         let l = npaths * cur as u128;
318        //         let r = npaths * (cur as u128 + 1);
319        //         if x >= l && x < r {
320        //             x -= l;
321        //             ds[i] = cur;
322        //             break;
323        //         }
324        //         if x < l {
325        //             high = cur;
326        //         } else {
327        //             // x >= r
328        //             low = cur;
329        //         }
330        //     }
331        // }
332    }
333    ds
334}
335
// Mark `AllWire` as a wire label that supports arithmetic operations.
impl ArithmeticWire for AllWire {}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    /// `from_repr` must invert `to_repr` for every modulus in `2..256`.
    #[test]
    fn packing() {
        let mut rng = thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let wire = AllWire::rand(&mut rng, q);
                let round_trip = AllWire::from_repr(wire.to_repr(), q);
                assert_eq!(wire, round_trip);
            }
        }
    }

    /// Digits produced by `from_repr` must match `util::as_base_q_u128`.
    #[test]
    fn base_conversion_lookup_method() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 5 + rng.gen_u16() % 110;
            let x = rng.gen_u128();
            let wire = AllWire::from_repr(U8x16::from(x), q);
            let expected = util::as_base_q_u128(x, q);
            assert_eq!(wire.digits(), expected, "x={} q={}", x, q);
        }
    }

    /// `hashback` must change the wire and yield a non-zero label.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + rng.gen_u16() % 110;
            let wire = AllWire::rand(&mut rng, q);
            let hashed = wire.hashback(1u128, q);
            assert!(wire != hashed);
            let nonzero = match hashed {
                AllWire::Mod2(WireMod2 { val }) => u128::from(val) > 0,
                AllWire::Mod3(WireMod3 { lsb, msb }) => lsb > 0 && msb > 0,
                AllWire::ModN(WireModQ { ds, .. }) => ds.iter().any(|&d| d != 0),
            };
            assert!(nonzero);
        }
    }

    /// Negating twice must return the original wire.
    #[test]
    fn negation() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let wire = AllWire::rand(&mut rng, q);
            let negated = -wire.clone();
            // q == 2 is skipped, presumably because mod-2 negation is the identity.
            if q != 2 {
                assert!(wire != negated);
            }
            assert_eq!(wire, -negated);
        }
    }

    /// Basic algebraic identities of the wire-label operators.
    #[test]
    #[allow(clippy::erasing_op)] // `x * 0` is intentional: it must equal `x - x`.
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            let zero = x.clone() - x.clone();

            assert_eq!(x.clone() * 0, zero);
            assert_eq!(x.clone() * q, zero);
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), zero);
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }

            let mut acc = x.clone();
            let expected_sum = acc.clone() + y.clone();
            acc += y;
            assert_eq!(acc, expected_sum);

            let mut doubled = x.clone();
            doubled *= 2;
            assert_eq!(x.clone() + x.clone(), doubled);

            let negated = -x.clone();
            assert_eq!(-x, negated);
        }
    }

    /// Every wire must carry exactly `digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let wire = AllWire::rand(&mut rng, q);
            assert_eq!(wire.digits().len(), util::digits_per_u128(q));
        }
    }

    /// Hashing on separate threads must agree with hashing on one thread.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let wires = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        // Spawn every worker first (collect forces this), then join them all.
        let handles = wires
            .iter()
            .cloned()
            .map(|wire| std::thread::spawn(move || wire.hash(0u128)))
            .collect_vec();
        let threaded = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect_vec();

        let sequential = wires.iter().map(|wire| wire.hash(0u128)).collect_vec();

        assert_eq!(threaded, sequential);
    }

    /// Serializing and deserializing through JSON must round-trip.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let wire = AllWire::rand(&mut rng, q);
            let json = serde_json::to_string(&wire).unwrap();
            let decoded: AllWire = serde_json::from_str(&json).unwrap();
            assert_eq!(wire, decoded);
        }
    }
}