// fancy_garbling: wire.rs

//! Wirelabels for use in garbled circuits.
//!
//! This module contains the [`WireLabel`] trait alongside several
//! instantiations of it. [`WireLabel`] is the core primitive underlying
//! garbled circuits: types implementing it represent an encoding of the value
//! on a given wire of the circuit.

8use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12    U8x16,
13    array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24/// Hash a batch of wires, using the same tweak for each wire.
25pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27    ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29    let batch = wires.array_map(|x| x.to_repr());
30    TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait for [`WireLabel`] instantiations that support arithmetic
/// operations.
///
/// The trait has no methods; implementing it only advertises the capability.
pub trait ArithmeticWire: Clone {}

37/// A trait that defines a wirelabel as used in garbled circuits.
38///
39/// At its core, a [`WireLabel`] is a way of encoding values, and operating on
40/// those encoded values.
41pub trait WireLabel:
42    Clone
43    + HasModulus
44    + core::ops::Add<Output = Self>
45    + core::ops::AddAssign
46    + core::ops::Sub<Output = Self>
47    + core::ops::SubAssign
48    + core::ops::Neg<Output = Self>
49    + core::ops::Mul<u16, Output = Self>
50    + core::ops::MulAssign<u16>
51{
52    /// Converts a [`WireLabel`] into its [`U8x16`] representation.
53    fn to_repr(&self) -> U8x16;
54
55    /// The color digit of the wire.
56    fn color(&self) -> u16;
57
58    /// Converts a [`U8x16`] into its [`WireLabel`] representation, based on the
59    /// modulus `q`.
60    ///
61    /// # Panics
62    /// This panics if `q` does not align with the modulus supported by the
63    /// [`WireLabel`].
64    fn from_repr(inp: U8x16, q: u16) -> Self;
65
66    /// A random [`WireLabel`] `mod q`, with the first digit set to `1`.
67    ///
68    /// # Panics
69    /// This panics if `q` does not align with the modulus supported by the
70    /// [`WireLabel`].
71    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
72
73    /// A random [`WireLabel`] `mod q`.
74    ///
75    /// # Panics
76    /// This panics if `q` does not align with the modulus supported by the
77    /// [`WireLabel`].
78    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
79
80    /// Converts a hashed block into a valid wire of the given modulus `q`.
81    ///
82    /// This is useful when separately using [`hash_wires`] to hash a set of
83    /// wires in one shot for efficiency reasons.
84    ///
85    /// # Panics
86    /// This panics if `q` does not align with the modulus supported by the
87    /// [`WireLabel`].
88    fn hash_to_mod(hash: U8x16, q: u16) -> Self;
89
90    /// Computes the hash of this [`WireLabel`], converting the result back into
91    /// a [`WireLabel`] based on the modulus `q`.
92    ///
93    /// This is equivalent to `WireLabel::hash_to_mod(self.hash(tweak), q)`, and
94    /// is useful when stringing together a sequence of operations on a
95    /// [`WireLabel`].
96    ///
97    /// # Panics
98    /// This panics if `q` does not align with the modulus supported by the
99    /// [`WireLabel`].
100    fn hashback(&self, tweak: u128, q: u16) -> Self {
101        let hash = self.hash(tweak);
102        Self::hash_to_mod(hash, q)
103    }
104
105    /// Computes the hash of the [`WireLabel`].
106    fn hash(&self, tweak: u128) -> U8x16 {
107        TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
108    }
109
110    /// Computes a [`WireLabel`] for `x % q`, returning both the zero
111    /// [`WireLabel`] as well as the [`WireLabel`] for `x % q`.
112    fn constant<RNG: CryptoRng + RngCore>(
113        x: u16,
114        q: u16,
115        delta: &Self,
116        rng: &mut RNG,
117    ) -> (Self, Self) {
118        let zero = Self::rand(rng, q);
119        let wire = zero.clone() + delta.clone() * x;
120        (zero, wire)
121    }
122}
123
124#[derive(Debug, Clone, PartialEq)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126/// A [`WireLabel`] that supports all possible moduli
127pub enum AllWire {
128    /// A `mod 2` [`WireLabel`].
129    Mod2(WireMod2),
130    /// A `mod 3` [`WireLabel`].
131    Mod3(WireMod3),
132    /// A `mod q` [`WireLabel`], where `3 < q < 2^16`.
133    ModN(WireModQ),
134}
135
136impl HasModulus for AllWire {
137    fn modulus(&self) -> u16 {
138        match &self {
139            AllWire::Mod2(x) => x.modulus(),
140            AllWire::Mod3(x) => x.modulus(),
141            AllWire::ModN(x) => x.modulus(),
142        }
143    }
144}
145
146impl core::ops::Add for AllWire {
147    type Output = Self;
148
149    fn add(self, rhs: Self) -> Self::Output {
150        let (p, q) = (self.modulus(), rhs.modulus());
151        match (self, rhs) {
152            (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
153            (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
154            (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
155            _ => panic!("unequal moduli: {p} != {q}"),
156        }
157    }
158}
159
160impl core::ops::AddAssign for AllWire {
161    fn add_assign(&mut self, rhs: Self) {
162        let (p, q) = (self.modulus(), rhs.modulus());
163        match (self, rhs) {
164            (Self::Mod2(x), Self::Mod2(y)) => *x += y,
165            (Self::Mod3(x), Self::Mod3(y)) => *x += y,
166            (Self::ModN(x), Self::ModN(y)) => *x += y,
167            _ => panic!("unequal moduli: {p} != {q}"),
168        }
169    }
170}
171
172impl core::ops::Sub for AllWire {
173    type Output = Self;
174
175    fn sub(self, rhs: Self) -> Self::Output {
176        self + -rhs
177    }
178}
179
180impl core::ops::SubAssign for AllWire {
181    fn sub_assign(&mut self, rhs: Self) {
182        *self = self.clone() - rhs;
183    }
184}
185
186impl core::ops::Neg for AllWire {
187    type Output = Self;
188
189    fn neg(self) -> Self::Output {
190        match self {
191            Self::Mod2(x) => Self::Mod2(-x),
192            Self::Mod3(x) => Self::Mod3(-x),
193            Self::ModN(x) => Self::ModN(-x),
194        }
195    }
196}
197
198impl core::ops::Mul<u16> for AllWire {
199    type Output = Self;
200
201    fn mul(self, rhs: u16) -> Self::Output {
202        match self {
203            Self::Mod2(x) => Self::Mod2(x * rhs),
204            Self::Mod3(x) => Self::Mod3(x * rhs),
205            Self::ModN(x) => Self::ModN(x * rhs),
206        }
207    }
208}
209
210impl core::ops::MulAssign<u16> for AllWire {
211    fn mul_assign(&mut self, rhs: u16) {
212        match self {
213            Self::Mod2(x) => {
214                *x *= rhs;
215            }
216            Self::Mod3(x) => {
217                *x *= rhs;
218            }
219            Self::ModN(x) => {
220                *x *= rhs;
221            }
222        };
223    }
224}
225
226impl WireLabel for AllWire {
227    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
228        match q {
229            2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
230            3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
231            _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
232        }
233    }
234
235    fn to_repr(&self) -> U8x16 {
236        match &self {
237            AllWire::Mod2(x) => x.to_repr(),
238            AllWire::Mod3(x) => x.to_repr(),
239            AllWire::ModN(x) => x.to_repr(),
240        }
241    }
242    fn color(&self) -> u16 {
243        match &self {
244            AllWire::Mod2(x) => x.color(),
245            AllWire::Mod3(x) => x.color(),
246            AllWire::ModN(x) => x.color(),
247        }
248    }
249    fn from_repr(inp: U8x16, q: u16) -> Self {
250        match q {
251            2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
252            3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
253            _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
254        }
255    }
256
257    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
258        match q {
259            2 => AllWire::Mod2(WireMod2::rand(rng, q)),
260            3 => AllWire::Mod3(WireMod3::rand(rng, q)),
261            _ => AllWire::ModN(WireModQ::rand(rng, q)),
262        }
263    }
264
265    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
266        if q == 3 {
267            AllWire::Mod3(WireMod3::encode_block_mod3(hash))
268        } else {
269            Self::from_repr(hash, q)
270        }
271    }
272}
273fn _unrank(inp: u128, q: u16) -> Vec<u16> {
274    let mut x = inp;
275    let ndigits = util::digits_per_u128(q);
276    let npaths_tab = npaths_tab::lookup(q);
277    x %= npaths_tab[ndigits - 1] * q as u128;
278
279    let mut ds = vec![0; ndigits];
280    for i in (0..ndigits).rev() {
281        let npaths = npaths_tab[i];
282
283        if q <= 23 {
284            // linear search
285            let mut acc = 0;
286            for j in 0..q {
287                acc += npaths;
288                if acc > x {
289                    x -= acc - npaths;
290                    ds[i] = j;
291                    break;
292                }
293            }
294        } else {
295            // naive division
296            let d = x / npaths;
297            ds[i] = d as u16;
298            x -= d * npaths;
299        }
300        // } else {
301        //     // binary search
302        //     let mut low = 0;
303        //     let mut high = q;
304        //     loop {
305        //         let cur = (low + high) / 2;
306        //         let l = npaths * cur as u128;
307        //         let r = npaths * (cur as u128 + 1);
308        //         if x >= l && x < r {
309        //             x -= l;
310        //             ds[i] = cur;
311        //             break;
312        //         }
313        //         if x < l {
314        //             high = cur;
315        //         } else {
316        //             // x >= r
317        //             low = cur;
318        //         }
319        //     }
320        // }
321    }
322    ds
323}
324
// `AllWire` covers every modulus, so it is marked as supporting arithmetic
// operations (see `ArithmeticWire`).
impl ArithmeticWire for AllWire {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    // `from_repr` must invert `to_repr` for every modulus in `2..256`.
    #[test]
    fn packing() {
        let mut rng = thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let wire = AllWire::rand(&mut rng, q);
                let roundtrip = AllWire::from_repr(wire.to_repr(), q);
                assert_eq!(wire, roundtrip);
            }
        }
    }

    // `WireModQ::from_repr` must produce the same digits as
    // `util::as_base_q_u128`, for moduli `5 <= q <= 114`.
    #[test]
    fn base_conversion_lookup_method() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let wire = WireModQ::from_repr(U8x16::from(x), q);
            let expected = util::as_base_q_u128(x, q);
            assert_eq!(wire.ds, expected, "x={} q={}", x, q);
        }
    }

    // `hashback` must change the label and must not yield a degenerate one.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            assert!(x != y);
            // Mod2: the value is nonzero. Mod3: both bit planes are nonzero.
            // ModN: not every digit is zero.
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    // Negating twice must be the identity.
    #[test]
    fn negation() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let xneg = -x.clone();
            // Mod 2, `-x == x`, so the inequality is only checked for `q != 2`.
            if q != 2 {
                assert!(x != xneg);
            }
            assert_eq!(x, -xneg);
        }
    }

    // Algebraic identities relating `+`, `-`, unary `-`, `*` and their
    // in-place forms.
    #[test]
    // `x * 0` below is deliberate: it checks that scaling by zero yields `x - x`.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            let zero = x.clone() - x.clone();

            assert_eq!(x.clone() * 0, zero);
            assert_eq!(x.clone() * q, zero);
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);

            if q == 2 {
                // Mod 2, addition and subtraction coincide.
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), zero);
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }

            let sum = x.clone() + y.clone();
            let mut acc = x.clone();
            acc += y;
            assert_eq!(acc, sum);

            let mut doubled = x.clone();
            doubled *= 2;
            assert_eq!(x.clone() + x.clone(), doubled);

            let negated = -x.clone();
            assert_eq!(-x, negated);
        }
    }

    // A random `WireModQ` must carry exactly `digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let wire = WireModQ::rand(&mut rng, q);
            assert_eq!(wire.ds.len(), util::digits_per_u128(q));
        }
    }

    // Hashing on separate threads must agree with hashing on this thread.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let wires = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let handles = wires
            .iter()
            .cloned()
            .map(|wire| std::thread::spawn(move || wire.hash(0u128)))
            .collect_vec();
        let threaded = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect_vec();

        let sequential = wires.iter().map(|wire| wire.hash(0u128)).collect_vec();

        assert_eq!(threaded, sequential);
    }

    // A JSON round trip must reproduce the original label.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let original = AllWire::rand(&mut rng, q);
            let json = serde_json::to_string(&original).unwrap();
            let restored: AllWire = serde_json::from_str(&json).unwrap();
            assert_eq!(original, restored);
        }
    }
}