fancy_garbling/wire/
modq.rs

1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Unvalidated form of [`WireModQ`], used as the serde deserialization target.
///
/// This struct performs no checks itself. Deserialization of a `WireModQ`
/// goes through `TryFrom<UntrustedWireModQ> for WireModQ`, which rejects:
/// - moduli below 2,
/// - digit vectors whose length differs from `util::digits_per_u128(q)`,
/// - digits that are not strictly less than `q`.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// The claimed modulus of the wire label (stored as `u16`, like `WireModQ::q`).
    q: u16,
    /// The claimed `mod-q` digits, not yet checked against `q`.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = crate::errors::WireDeserializationError;

    /// Validates an [`UntrustedWireModQ`] and converts it into a [`WireModQ`].
    ///
    /// The checks run in this order, and the first failing one is reported
    /// (wrapped in `InvalidWireModQ`):
    /// 1. `q >= 2`, otherwise `BadModulus(q)`;
    /// 2. `ds.len() == digits_per_u128(q)`, otherwise `InvalidDigitsLength`;
    /// 3. every digit is `< q`, otherwise `DigitTooLarge` for the first offender.
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        use crate::errors::ModQDeserializationError as Reason;

        let UntrustedWireModQ { q, ds } = wire;
        let reject = |reason: Reason| -> Result<Self, Self::Error> {
            Err(crate::errors::WireDeserializationError::InvalidWireModQ(reason))
        };

        if q < 2 {
            return reject(Reason::BadModulus(q));
        }

        let needed = crate::util::digits_per_u128(q);
        let got = ds.len();
        if got != needed {
            return reject(Reason::InvalidDigitsLength { got, needed });
        }

        let oversized = ds.iter().copied().find(|&digit| digit >= q);
        if let Some(digit) = oversized {
            return reject(Reason::DigitTooLarge { digit, modulus: q });
        }

        Ok(WireModQ { q, ds })
    }
}
54
/// Representation of a `mod-q` wire label.
///
/// A label is stored as its modulus `q` together with a list of `mod-q`
/// digits. Labels produced by the constructors in this file have
/// `util::digits_per_u128(q)` digits, each strictly less than `q`.
///
/// The type is documented as the representation for moduli `q > 3`
/// (NOTE(review): presumably `q = 2` and `q = 3` use dedicated wire types
/// elsewhere — confirm). However, every constructor in this file accepts any `q >= 2`.
///
/// The derived [`Default`] value (`q == 0`, no digits) does not satisfy the
/// `q >= 2` invariant enforced by the constructors in this file.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "UntrustedWireModQ"))]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WireModQ {
    /// The modulus of the wire label (assumed to fit in a `u16`).
    q: u16,
    /// The `mod-q` digits of the label; each is expected to be `< q`.
    pub(crate) ds: Vec<u16>,
}
69
70impl HasModulus for WireModQ {
71    fn modulus(&self) -> u16 {
72        self.q
73    }
74}
75
76impl core::ops::Add for WireModQ {
77    type Output = Self;
78
79    fn add(self, rhs: Self) -> Self::Output {
80        assert_eq!(self.q, rhs.q);
81
82        let mut xs = self.ds.clone();
83        let ys = &rhs.ds;
84        let q = self.q;
85
86        debug_assert_eq!(xs.len(), ys.len());
87        xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
88            let (zp, overflow) = (*x + y).overflowing_sub(q);
89            *x = if overflow { *x + y } else { zp }
90        });
91        Self { ds: xs, q }
92    }
93}
94
95impl core::ops::AddAssign for WireModQ {
96    fn add_assign(&mut self, rhs: Self) {
97        assert_eq!(self.q, rhs.q);
98
99        let q = self.q;
100
101        debug_assert_eq!(self.ds.len(), rhs.ds.len());
102        self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
103            let (zp, overflow) = (*x + y).overflowing_sub(q);
104            *x = if overflow { *x + y } else { zp }
105        });
106    }
107}
108
109impl core::ops::Sub for WireModQ {
110    type Output = Self;
111
112    fn sub(self, rhs: Self) -> Self::Output {
113        self + -rhs
114    }
115}
116
117impl core::ops::SubAssign for WireModQ {
118    fn sub_assign(&mut self, rhs: Self) {
119        *self = self.clone() - rhs;
120    }
121}
122
123impl core::ops::Neg for WireModQ {
124    type Output = Self;
125
126    fn neg(self) -> Self::Output {
127        let q = self.q;
128        let mut ds = self.ds.clone();
129        ds.iter_mut().for_each(|d| {
130            if *d > 0 {
131                *d = q - *d;
132            } else {
133                *d = 0;
134            }
135        });
136        Self { q, ds }
137    }
138}
139
140impl core::ops::Mul<u16> for WireModQ {
141    type Output = Self;
142
143    fn mul(self, rhs: u16) -> Self::Output {
144        let q = self.q;
145        let mut ds = self.ds.clone();
146        ds.iter_mut()
147            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
148        Self { ds, q }
149    }
150}
151
152impl core::ops::MulAssign<u16> for WireModQ {
153    fn mul_assign(&mut self, rhs: u16) {
154        let q = self.q;
155        self.ds
156            .iter_mut()
157            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
158    }
159}
160
161impl WireLabel for WireModQ {
162    fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
163        if q < 2 {
164            panic!(
165                "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
166                q
167            );
168        }
169        let mut w = Self::rand(rng, q);
170        w.ds[0] = 1;
171        w
172    }
173
174    fn digits(&self) -> Vec<u16> {
175        self.ds.clone()
176    }
177
178    fn to_repr(&self) -> U8x16 {
179        // This function converts a [`WireMod3`] into its [`Block`] representation.
180        // The values stored in [`WireModQ`] are repacked depending on q
181        // into a 128b value as a [`Block`].
182        util::from_base_q(&self.ds, self.q).into()
183    }
184
185    fn color(&self) -> u16 {
186        let color = self.ds[0];
187        debug_assert!(color < self.q);
188        color
189    }
190
191    fn from_repr(inp: U8x16, q: u16) -> Self {
192        if q < 2 {
193            panic!(
194                "[WireModQ::from_block] Modulus must be at least 2. Got {}",
195                q
196            );
197        }
198        // This function converts a Block into its WireLabel representation
199        // by splitting the Block into several digits mod q that can each fit
200        // into 128b.
201        let ds = if util::is_power_of_2(q) {
202            // It's a power of 2, just split the digits.
203            let ndigits = util::digits_per_u128(q);
204            let width = 128 / ndigits;
205            let mask = (1 << width) - 1;
206            let x = u128::from(inp);
207            (0..ndigits)
208                .map(|i| ((x >> (width * i)) & mask) as u16)
209                .collect::<Vec<u16>>()
210        } else if q <= 23 {
211            _unrank(u128::from(inp), q)
212        } else {
213            // If all else fails, do unrank using naive division.
214            _unrank(u128::from(inp), q)
215        };
216        Self { q, ds }
217    }
218    /// Unpack the wire represented by a `Block` with modulus `q`. Assumes that
219    /// the block was constructed through the `AllWire` API.
220    fn zero(q: u16) -> Self {
221        if q < 2 {
222            panic!("[WireModQ::zero] Modulus must be at least 2. Got {}", q);
223        }
224        Self {
225            q,
226            ds: vec![0; util::digits_per_u128(q)],
227        }
228    }
229    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
230        if q < 2 {
231            panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
232        }
233        let ds = (0..util::digits_per_u128(q))
234            .map(|_| rng.r#gen::<u16>() % q)
235            .collect();
236        Self { q, ds }
237    }
238
239    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
240        if q < 2 {
241            panic!(
242                "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
243                q
244            );
245        }
246        Self::from_repr(hash, q)
247    }
248}
249
250impl ArithmeticWire for WireModQ {}
251
#[cfg(test)]
mod tests {
    use super::WireModQ;
    #[cfg(feature = "serde")]
    use crate::WireLabel;
    #[cfg(feature = "serde")]
    use rand::{Rng, thread_rng};

    /// Exercises the operator impls on a small modulus, checking the in-place
    /// operators against their by-value counterparts.
    #[test]
    fn test_arithmetic_small_modulus() {
        let q = 7;
        let a = WireModQ { q, ds: vec![3, 6, 0] };
        let b = WireModQ { q, ds: vec![5, 1, 4] };

        assert_eq!((a.clone() + b.clone()).ds, vec![1, 0, 4]);
        assert_eq!((a.clone() - b.clone()).ds, vec![5, 5, 3]);
        assert_eq!((-a.clone()).ds, vec![4, 1, 0]);
        assert_eq!((a.clone() * 3).ds, vec![2, 4, 0]);

        let mut c = a.clone();
        c += b.clone();
        assert_eq!(c, a.clone() + b.clone());
        c -= b.clone();
        assert_eq!(c, a);
        c *= 3;
        assert_eq!(c, a.clone() * 3);
    }

    /// Random wires for random valid moduli survive a JSON round trip.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_modq() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let q: u16 = rng.gen_range(2..=u16::MAX);
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    /// A serialized wire whose modulus was tampered down to 1 is rejected.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modq_mod() {
        let mut rng = thread_rng();
        let q: u16 = rng.gen_range(2..=u16::MAX);

        let mut w = WireModQ::rand(&mut rng, q);

        // Manually mess with the modulus
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    /// A `q = 2` wire containing the digit 5 (which is not < q) is rejected.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modq_ds_mod() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }

    /// A `q = 2` wire with the wrong number of digits is rejected. The fixture
    /// appears to be the one above with the out-of-range digit removed.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modq_ds_count() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }
}