// fancy_garbling/wire.rs

//! Wirelabels for use in garbled circuits.
//!
//! This module contains the [`WireLabel`] trait alongside several
//! instantiations of it. [`WireLabel`] is the core primitive underlying garbled
//! circuits: it represents an encoding of the value carried on a given wire of
//! the circuit.

8use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12    U8x16,
13    array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24/// Hash a batch of wires, using the same tweak for each wire.
25pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27    ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29    let batch = wires.array_map(|x| x.to_repr());
30    TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait for [`WireLabel`] instantiations that support arithmetic
/// (not just boolean) gate operations.
///
/// Implementors must be [`Clone`].
pub trait ArithmeticWire
where
    Self: Clone,
{
}
36
37/// A trait that defines a wirelabel as used in garbled circuits.
38///
39/// At its core, a [`WireLabel`] is a way of encoding values, and operating on
40/// those encoded values.
41pub trait WireLabel:
42    Clone
43    + core::fmt::Debug
44    + HasModulus
45    + core::ops::Add<Output = Self>
46    + core::ops::AddAssign
47    + core::ops::Sub<Output = Self>
48    + core::ops::SubAssign
49    + core::ops::Neg<Output = Self>
50    + core::ops::Mul<u16, Output = Self>
51    + core::ops::MulAssign<u16>
52{
53    /// Converts a [`WireLabel`] into its [`U8x16`] representation.
54    fn to_repr(&self) -> U8x16;
55
56    /// The color digit of the wire.
57    fn color(&self) -> u16;
58
59    /// Converts a [`U8x16`] into its [`WireLabel`] representation, based on the
60    /// modulus `q`.
61    ///
62    /// # Panics
63    /// This panics if `q` does not align with the modulus supported by the
64    /// [`WireLabel`].
65    fn from_repr(inp: U8x16, q: u16) -> Self;
66
67    /// A random [`WireLabel`] `mod q`, with the first digit set to `1`.
68    ///
69    /// # Panics
70    /// This panics if `q` does not align with the modulus supported by the
71    /// [`WireLabel`].
72    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
73
74    /// A random [`WireLabel`] `mod q`.
75    ///
76    /// # Panics
77    /// This panics if `q` does not align with the modulus supported by the
78    /// [`WireLabel`].
79    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
80
81    /// Converts a hashed block into a valid wire of the given modulus `q`.
82    ///
83    /// This is useful when separately using [`hash_wires`] to hash a set of
84    /// wires in one shot for efficiency reasons.
85    ///
86    /// # Panics
87    /// This panics if `q` does not align with the modulus supported by the
88    /// [`WireLabel`].
89    fn hash_to_mod(hash: U8x16, q: u16) -> Self;
90
91    /// Computes the hash of this [`WireLabel`], converting the result back into
92    /// a [`WireLabel`] based on the modulus `q`.
93    ///
94    /// This is equivalent to `WireLabel::hash_to_mod(self.hash(tweak), q)`, and
95    /// is useful when stringing together a sequence of operations on a
96    /// [`WireLabel`].
97    ///
98    /// # Panics
99    /// This panics if `q` does not align with the modulus supported by the
100    /// [`WireLabel`].
101    fn hashback(&self, tweak: u128, q: u16) -> Self {
102        let hash = self.hash(tweak);
103        Self::hash_to_mod(hash, q)
104    }
105
106    /// Computes the hash of the [`WireLabel`].
107    fn hash(&self, tweak: u128) -> U8x16 {
108        TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
109    }
110
111    /// Computes a [`WireLabel`] for `x % q`, returning both the zero
112    /// [`WireLabel`] as well as the [`WireLabel`] for `x % q`.
113    fn constant<RNG: CryptoRng + RngCore>(
114        x: u16,
115        q: u16,
116        delta: &Self,
117        rng: &mut RNG,
118    ) -> (Self, Self) {
119        let zero = Self::rand(rng, q);
120        let wire = zero.clone() + delta.clone() * x;
121        (zero, wire)
122    }
123}
124
125#[derive(Debug, Clone, PartialEq)]
126#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
127/// A [`WireLabel`] that supports all possible moduli
128pub enum AllWire {
129    /// A `mod 2` [`WireLabel`].
130    Mod2(WireMod2),
131    /// A `mod 3` [`WireLabel`].
132    Mod3(WireMod3),
133    /// A `mod q` [`WireLabel`], where `3 < q < 2^16`.
134    ModN(WireModQ),
135}
136
137impl HasModulus for AllWire {
138    fn modulus(&self) -> u16 {
139        match &self {
140            AllWire::Mod2(x) => x.modulus(),
141            AllWire::Mod3(x) => x.modulus(),
142            AllWire::ModN(x) => x.modulus(),
143        }
144    }
145}
146
147impl core::ops::Add for AllWire {
148    type Output = Self;
149
150    fn add(self, rhs: Self) -> Self::Output {
151        let (p, q) = (self.modulus(), rhs.modulus());
152        match (self, rhs) {
153            (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
154            (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
155            (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
156            _ => panic!("unequal moduli: {p} != {q}"),
157        }
158    }
159}
160
161impl core::ops::AddAssign for AllWire {
162    fn add_assign(&mut self, rhs: Self) {
163        let (p, q) = (self.modulus(), rhs.modulus());
164        match (self, rhs) {
165            (Self::Mod2(x), Self::Mod2(y)) => *x += y,
166            (Self::Mod3(x), Self::Mod3(y)) => *x += y,
167            (Self::ModN(x), Self::ModN(y)) => *x += y,
168            _ => panic!("unequal moduli: {p} != {q}"),
169        }
170    }
171}
172
173impl core::ops::Sub for AllWire {
174    type Output = Self;
175
176    fn sub(self, rhs: Self) -> Self::Output {
177        self + -rhs
178    }
179}
180
181impl core::ops::SubAssign for AllWire {
182    fn sub_assign(&mut self, rhs: Self) {
183        *self = self.clone() - rhs;
184    }
185}
186
187impl core::ops::Neg for AllWire {
188    type Output = Self;
189
190    fn neg(self) -> Self::Output {
191        match self {
192            Self::Mod2(x) => Self::Mod2(-x),
193            Self::Mod3(x) => Self::Mod3(-x),
194            Self::ModN(x) => Self::ModN(-x),
195        }
196    }
197}
198
199impl core::ops::Mul<u16> for AllWire {
200    type Output = Self;
201
202    fn mul(self, rhs: u16) -> Self::Output {
203        match self {
204            Self::Mod2(x) => Self::Mod2(x * rhs),
205            Self::Mod3(x) => Self::Mod3(x * rhs),
206            Self::ModN(x) => Self::ModN(x * rhs),
207        }
208    }
209}
210
211impl core::ops::MulAssign<u16> for AllWire {
212    fn mul_assign(&mut self, rhs: u16) {
213        match self {
214            Self::Mod2(x) => {
215                *x *= rhs;
216            }
217            Self::Mod3(x) => {
218                *x *= rhs;
219            }
220            Self::ModN(x) => {
221                *x *= rhs;
222            }
223        };
224    }
225}
226
227impl WireLabel for AllWire {
228    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
229        match q {
230            2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
231            3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
232            _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
233        }
234    }
235
236    fn to_repr(&self) -> U8x16 {
237        match &self {
238            AllWire::Mod2(x) => x.to_repr(),
239            AllWire::Mod3(x) => x.to_repr(),
240            AllWire::ModN(x) => x.to_repr(),
241        }
242    }
243    fn color(&self) -> u16 {
244        match &self {
245            AllWire::Mod2(x) => x.color(),
246            AllWire::Mod3(x) => x.color(),
247            AllWire::ModN(x) => x.color(),
248        }
249    }
250    fn from_repr(inp: U8x16, q: u16) -> Self {
251        match q {
252            2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
253            3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
254            _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
255        }
256    }
257
258    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
259        match q {
260            2 => AllWire::Mod2(WireMod2::rand(rng, q)),
261            3 => AllWire::Mod3(WireMod3::rand(rng, q)),
262            _ => AllWire::ModN(WireModQ::rand(rng, q)),
263        }
264    }
265
266    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
267        if q == 3 {
268            AllWire::Mod3(WireMod3::encode_block_mod3(hash))
269        } else {
270            Self::from_repr(hash, q)
271        }
272    }
273}
/// Splits `inp` into `util::digits_per_u128(q)` base-`q` digits, using the
/// table returned by `npaths_tab::lookup(q)`.
///
/// Digit `i` of the result is the quotient of the running value by table
/// entry `npaths_tab[i]`. Digits are extracted from the highest index down to
/// 0.
///
/// NOTE(review): this presumes `npaths_tab[i]` is the place value of digit `i`
/// (i.e. `q^i`), which would make the result least-significant digit first.
/// Confirm against `npaths_tab`.
///
/// Nothing in this file calls this function. The leading underscore keeps the
/// dead-code lint quiet.
fn _unrank(inp: u128, q: u16) -> Vec<u16> {
    let mut x = inp;
    let ndigits = util::digits_per_u128(q);
    let npaths_tab = npaths_tab::lookup(q);
    // Reduce `x` into the range that `ndigits` digits can represent.
    // NOTE(review): if `npaths_tab[ndigits - 1] * q` reaches 2^128, this
    // multiplication overflows (panic in debug, wrap in release). Confirm the
    // table and digit count rule that out for every supported `q`.
    x %= npaths_tab[ndigits - 1] * q as u128;

    let mut ds = vec![0; ndigits];
    for i in (0..ndigits).rev() {
        let npaths = npaths_tab[i];

        if q <= 23 {
            // linear search: find the first `j` with `(j + 1) * npaths > x`,
            // store it as digit `i`, and remove `j * npaths` from `x`.
            // If no `j < q` qualifies, `ds[i]` stays 0 and `x` is unchanged.
            // The cutoff of 23 looks like a performance heuristic
            // (small q favours repeated addition over division). TODO confirm.
            let mut acc = 0;
            for j in 0..q {
                acc += npaths;
                if acc > x {
                    x -= acc - npaths;
                    ds[i] = j;
                    break;
                }
            }
        } else {
            // naive division: the digit is `x / npaths`, and `x` keeps the
            // remainder. The `as u16` cast truncates silently, so this relies
            // on `x < q * npaths` holding at this point.
            let d = x / npaths;
            ds[i] = d as u16;
            x -= d * npaths;
        }
        // A binary-search alternative used to sit here as commented-out code.
        // It was dropped: its `low = cur` update could fail to terminate.
    }
    ds
}
325
/// `AllWire` supports arithmetic gates. It implements `Add`, `Sub`, `Neg`,
/// `Mul<u16>` and their assigning forms for every variant.
impl ArithmeticWire for AllWire {}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    /// `from_repr` must invert `to_repr` for every modulus in `2..256`.
    #[test]
    fn packing() {
        let mut rng = thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let wire = AllWire::rand(&mut rng, q);
                let repacked = AllWire::from_repr(wire.to_repr(), q);
                assert_eq!(wire, repacked);
            }
        }
    }

    /// `WireModQ::from_repr` must agree with the reference base-`q`
    /// conversion in `util`.
    #[test]
    fn base_conversion_lookup_method() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let wire = WireModQ::from_repr(U8x16::from(x), q);
            let expected = util::as_base_q_u128(x, q);
            assert_eq!(wire.ds, expected, "x={} q={}", x, q);
        }
    }

    /// `hashback` should change the label and should not return the all-zero
    /// label.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            assert!(x != y);
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    /// Negation is an involution, and changes the label unless `q == 2`.
    #[test]
    fn negation() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let xneg = -x.clone();
            // Mod 2, every label is its own negation, so skip that check.
            if q != 2 {
                assert!(x != xneg);
            }
            let restored = -xneg;
            assert_eq!(x, restored);
        }
    }

    /// The binary and compound-assignment operators must satisfy the usual
    /// ring identities and agree with each other.
    #[test]
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let a = AllWire::rand(&mut rng, q);
            let b = AllWire::rand(&mut rng, q);
            let zero = a.clone() - a.clone();

            // Scaling by 0 or by the modulus itself yields the zero label.
            assert_eq!(a.clone() * 0, zero);
            assert_eq!(a.clone() * q, zero);
            // Repeated addition agrees with scalar multiplication.
            assert_eq!(a.clone() + a.clone(), a.clone() * 2);
            assert_eq!(a.clone() + a.clone() + a.clone(), a.clone() * 3);
            // Double negation is the identity.
            assert_eq!(-(-a.clone()), a);
            if q == 2 {
                // Mod 2, addition and subtraction coincide.
                assert_eq!(a.clone() + b.clone(), a.clone() - b.clone());
            } else {
                assert_eq!(a.clone() + -a.clone(), zero);
                assert_eq!(a.clone() + -b.clone(), a.clone() - b.clone());
            }

            // Compound assignment matches the binary operators.
            let mut acc = a.clone();
            let sum = acc.clone() + b.clone();
            acc += b;
            assert_eq!(acc, sum);

            let mut doubled = a.clone();
            doubled *= 2;
            assert_eq!(a.clone() + a.clone(), doubled);

            let mut negated = a.clone();
            negated = -negated;
            assert_eq!(-a, negated);
        }
    }

    /// A random `WireModQ` must carry exactly `digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let wire = WireModQ::rand(&mut rng, q);
            assert_eq!(wire.ds.len(), util::digits_per_u128(q));
        }
    }

    /// Hashing on separate threads must give the same result as hashing on
    /// the current thread.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let wires = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let handles = wires
            .iter()
            .cloned()
            .map(|wire| std::thread::spawn(move || wire.hash(0u128)))
            .collect_vec();
        let hashes = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect_vec();

        let expected = wires.iter().map(|wire| wire.hash(0u128)).collect_vec();
        assert_eq!(hashes, expected);
    }

    /// A JSON round trip must reproduce the original label.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let wire = AllWire::rand(&mut rng, q);
            let json = serde_json::to_string(&wire).unwrap();

            let roundtrip: AllWire = serde_json::from_str(&json).unwrap();

            assert_eq!(wire, roundtrip);
        }
    }
}