fancy_garbling/wire/
modq.rs

1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Unvalidated mirror of [`WireModQ`] that serde deserializes into first.
///
/// Data of this type only becomes a `WireModQ` through its `TryFrom`
/// conversion. That conversion rejects a modulus below 2, a digit vector of
/// the wrong length, and any digit that is not strictly less than the modulus.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// Claimed modulus of the wire label (assumed to fit in a `u16`).
    q: u16,
    /// Claimed `mod-q` digits; not yet checked against `q`.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = swanky_error::Error;

    /// Validate deserialized data before building a [`WireModQ`].
    ///
    /// # Errors
    ///
    /// Returns an `OtherError` if:
    /// - the modulus is below 2;
    /// - the number of digits differs from `digits_per_u128(q)`;
    /// - any digit is not strictly less than `q`.
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        let UntrustedWireModQ { q, ds } = wire;

        swanky_error::ensure!(
            q >= 2,
            swanky_error::ErrorKind::OtherError,
            "Modulus must be at least two",
        );

        // The digit count must match what `digits_per_u128` reports for this modulus.
        let expected_len = crate::util::digits_per_u128(q);
        let given_len = ds.len();
        swanky_error::ensure!(
            given_len == expected_len,
            swanky_error::ErrorKind::OtherError,
            "Invalid number of digits. Expected: {expected_len}. Got: {given_len}"
        );

        // Every digit must already be reduced mod q; `d == q` is invalid too,
        // so the message says "not less than" rather than "greater than".
        if let Some((i, d)) = ds.iter().copied().enumerate().find(|&(_, d)| d >= q) {
            swanky_error::bail!(
                swanky_error::ErrorKind::OtherError,
                "Digit {i} (value {d}) is not less than the modulus {q}",
            );
        }

        Ok(WireModQ { q, ds })
    }
}
48
/// A wire label whose digits live in `Z_q`.
///
/// The label is stored as its modulus `q` (assumed to fit in a `u16`)
/// together with a vector of `mod-q` digits.
///
/// The type is intended for moduli `q > 3`. The validating constructors
/// (`rand`, `rand_delta`, `from_repr`, `hash_to_mod`, and deserialization)
/// accept any `q >= 2`. Note that `Default` yields `q = 0` with no digits.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "UntrustedWireModQ")
)]
pub struct WireModQ {
    /// Modulus shared by every digit of the label.
    q: u16,
    /// The label's `mod-q` digits.
    pub(crate) ds: Vec<u16>,
}
63
64impl HasModulus for WireModQ {
65    fn modulus(&self) -> u16 {
66        self.q
67    }
68}
69
70impl core::ops::Add for WireModQ {
71    type Output = Self;
72
73    fn add(self, rhs: Self) -> Self::Output {
74        assert_eq!(self.q, rhs.q);
75
76        let mut xs = self.ds.clone();
77        let ys = &rhs.ds;
78        let q = self.q;
79
80        debug_assert_eq!(xs.len(), ys.len());
81        xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
82            let (zp, overflow) = (*x + y).overflowing_sub(q);
83            *x = if overflow { *x + y } else { zp }
84        });
85        Self { ds: xs, q }
86    }
87}
88
89impl core::ops::AddAssign for WireModQ {
90    fn add_assign(&mut self, rhs: Self) {
91        assert_eq!(self.q, rhs.q);
92
93        let q = self.q;
94
95        debug_assert_eq!(self.ds.len(), rhs.ds.len());
96        self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
97            let (zp, overflow) = (*x + y).overflowing_sub(q);
98            *x = if overflow { *x + y } else { zp }
99        });
100    }
101}
102
103impl core::ops::Sub for WireModQ {
104    type Output = Self;
105
106    fn sub(self, rhs: Self) -> Self::Output {
107        self + -rhs
108    }
109}
110
111impl core::ops::SubAssign for WireModQ {
112    fn sub_assign(&mut self, rhs: Self) {
113        *self = self.clone() - rhs;
114    }
115}
116
117impl core::ops::Neg for WireModQ {
118    type Output = Self;
119
120    fn neg(self) -> Self::Output {
121        let q = self.q;
122        let mut ds = self.ds.clone();
123        ds.iter_mut().for_each(|d| {
124            if *d > 0 {
125                *d = q - *d;
126            } else {
127                *d = 0;
128            }
129        });
130        Self { q, ds }
131    }
132}
133
134impl core::ops::Mul<u16> for WireModQ {
135    type Output = Self;
136
137    fn mul(self, rhs: u16) -> Self::Output {
138        let q = self.q;
139        let mut ds = self.ds.clone();
140        ds.iter_mut()
141            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
142        Self { ds, q }
143    }
144}
145
146impl core::ops::MulAssign<u16> for WireModQ {
147    fn mul_assign(&mut self, rhs: u16) {
148        let q = self.q;
149        self.ds
150            .iter_mut()
151            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
152    }
153}
154
155impl WireLabel for WireModQ {
156    fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
157        if q < 2 {
158            panic!(
159                "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
160                q
161            );
162        }
163        let mut w = Self::rand(rng, q);
164        w.ds[0] = 1;
165        w
166    }
167
168    fn digits(&self) -> Vec<u16> {
169        self.ds.clone()
170    }
171
172    fn to_repr(&self) -> U8x16 {
173        // This function converts a [`WireMod3`] into its [`Block`] representation.
174        // The values stored in [`WireModQ`] are repacked depending on q
175        // into a 128b value as a [`Block`].
176        util::from_base_q(&self.ds, self.q).into()
177    }
178
179    fn color(&self) -> u16 {
180        let color = self.ds[0];
181        debug_assert!(color < self.q);
182        color
183    }
184
185    fn from_repr(inp: U8x16, q: u16) -> Self {
186        if q < 2 {
187            panic!(
188                "[WireModQ::from_block] Modulus must be at least 2. Got {}",
189                q
190            );
191        }
192        // This function converts a Block into its WireLabel representation
193        // by splitting the Block into several digits mod q that can each fit
194        // into 128b.
195        let ds = if util::is_power_of_2(q) {
196            // It's a power of 2, just split the digits.
197            let ndigits = util::digits_per_u128(q);
198            let width = 128 / ndigits;
199            let mask = (1 << width) - 1;
200            let x = u128::from(inp);
201            (0..ndigits)
202                .map(|i| ((x >> (width * i)) & mask) as u16)
203                .collect::<Vec<u16>>()
204        } else if q <= 23 {
205            _unrank(u128::from(inp), q)
206        } else {
207            // If all else fails, do unrank using naive division.
208            _unrank(u128::from(inp), q)
209        };
210        Self { q, ds }
211    }
212
213    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
214        if q < 2 {
215            panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
216        }
217        let ds = (0..util::digits_per_u128(q))
218            .map(|_| rng.r#gen::<u16>() % q)
219            .collect();
220        Self { q, ds }
221    }
222
223    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
224        if q < 2 {
225            panic!(
226                "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
227                q
228            );
229        }
230        Self::from_repr(hash, q)
231    }
232}
233
234impl ArithmeticWire for WireModQ {}
235
#[cfg(test)]
mod tests {
    #[cfg(feature = "serde")]
    use super::WireModQ;
    #[cfg(feature = "serde")]
    use crate::WireLabel;
    #[cfg(feature = "serde")]
    use rand::Rng;
    #[cfg(feature = "serde")]
    use rand::thread_rng;

    /// Draw a random modulus that `WireModQ` accepts, uniformly from `2..=u16::MAX`.
    #[cfg(feature = "serde")]
    fn random_modulus(rng: &mut impl Rng) -> u16 {
        rng.gen_range(2..=u16::MAX)
    }

    /// A randomly generated wire survives a JSON round trip unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_mod_q() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let q = random_modulus(&mut rng);
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    /// A modulus below 2 is rejected on deserialization.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_mod() {
        let mut rng = thread_rng();
        let q = random_modulus(&mut rng);

        let mut w = WireModQ::rand(&mut rng, q);

        // Manually mess with the modulus
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    /// A digit list containing a `5` (out of range for q = 2) is rejected.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_mod() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }

    /// A digit list that is one digit short of the expected count is rejected.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_count() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }
}