Skip to main content

fancy_garbling/wire/
modq.rs

1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Unvalidated deserialization target for [`WireModQ`].
///
/// `WireModQ` is deserialized through this struct (via
/// `serde(try_from = "UntrustedWireModQ")`). The `TryFrom` conversion then
/// rejects a modulus below 2, a digit count other than
/// `util::digits_per_u128(q)`, and any digit that is not less than `q`.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// The modulus of the wire label, not yet checked to be at least 2.
    q: u16,
    /// The `mod-q` digits exactly as read from the input, not yet validated.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = swanky_error::Error;

    /// Validates a deserialized [`UntrustedWireModQ`] and converts it into a
    /// [`WireModQ`].
    ///
    /// # Errors
    ///
    /// Returns an `OtherError` in three cases:
    /// - the modulus is below 2;
    /// - the number of digits differs from `util::digits_per_u128(q)`;
    /// - some digit is not strictly less than the modulus.
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        let UntrustedWireModQ { q, ds } = wire;

        // Checked first so that `digits_per_u128` is never asked about a
        // degenerate modulus.
        swanky_error::ensure!(
            q >= 2,
            swanky_error::ErrorKind::OtherError,
            "Modulus must be at least two",
        );

        // The digit count is fixed by the modulus.
        let expected_len = crate::util::digits_per_u128(q);
        let given_len = ds.len();
        swanky_error::ensure!(
            given_len == expected_len,
            swanky_error::ErrorKind::OtherError,
            "Invalid number of digits. Expected: {expected_len}. Got: {given_len}"
        );

        // Every digit must lie in `0..q`. Report the first one that does not,
        // including the equal-to-`q` case.
        if let Some((i, d)) = ds.iter().copied().enumerate().find(|&(_, d)| d >= q) {
            swanky_error::bail!(
                swanky_error::ErrorKind::OtherError,
                "Digit {i} ({d}) is greater than or equal to the modulus {q}",
            );
        }

        Ok(WireModQ { q, ds })
    }
}
48
/// A wire label made of digits modulo `q`.
///
/// The modulus is stored as a `u16`, and the label itself is a list of
/// `mod-q` digits that is converted to and from a 128-bit `U8x16` block by the
/// `WireLabel` impl. The constructors in this file and the `serde` path all
/// reject moduli below 2.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "UntrustedWireModQ")
)]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WireModQ {
    /// The modulus shared by every digit in `ds`.
    q: u16,
    /// The label's digits, each expected to lie in `0..q`.
    pub(crate) ds: Vec<u16>,
}
63
64impl HasModulus for WireModQ {
65    fn modulus(&self) -> u16 {
66        self.q
67    }
68}
69
70impl core::ops::Add for WireModQ {
71    type Output = Self;
72
73    fn add(self, rhs: Self) -> Self::Output {
74        assert_eq!(self.q, rhs.q);
75
76        let mut xs = self.ds.clone();
77        let ys = &rhs.ds;
78        let q = self.q;
79
80        debug_assert_eq!(xs.len(), ys.len());
81        xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
82            let (zp, overflow) = (*x + y).overflowing_sub(q);
83            *x = if overflow { *x + y } else { zp }
84        });
85        Self { ds: xs, q }
86    }
87}
88
89impl core::ops::AddAssign for WireModQ {
90    fn add_assign(&mut self, rhs: Self) {
91        assert_eq!(self.q, rhs.q);
92
93        let q = self.q;
94
95        debug_assert_eq!(self.ds.len(), rhs.ds.len());
96        self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
97            let (zp, overflow) = (*x + y).overflowing_sub(q);
98            *x = if overflow { *x + y } else { zp }
99        });
100    }
101}
102
103impl core::ops::Sub for WireModQ {
104    type Output = Self;
105
106    fn sub(self, rhs: Self) -> Self::Output {
107        self + -rhs
108    }
109}
110
111impl core::ops::SubAssign for WireModQ {
112    fn sub_assign(&mut self, rhs: Self) {
113        *self = self.clone() - rhs;
114    }
115}
116
117impl core::ops::Neg for WireModQ {
118    type Output = Self;
119
120    fn neg(self) -> Self::Output {
121        let q = self.q;
122        let mut ds = self.ds.clone();
123        ds.iter_mut().for_each(|d| {
124            if *d > 0 {
125                *d = q - *d;
126            } else {
127                *d = 0;
128            }
129        });
130        Self { q, ds }
131    }
132}
133
134impl core::ops::Mul<u16> for WireModQ {
135    type Output = Self;
136
137    fn mul(self, rhs: u16) -> Self::Output {
138        let q = self.q;
139        let mut ds = self.ds.clone();
140        ds.iter_mut()
141            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
142        Self { ds, q }
143    }
144}
145
146impl core::ops::MulAssign<u16> for WireModQ {
147    fn mul_assign(&mut self, rhs: u16) {
148        let q = self.q;
149        self.ds
150            .iter_mut()
151            .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
152    }
153}
154
155impl WireLabel for WireModQ {
156    fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
157        if q < 2 {
158            panic!(
159                "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
160                q
161            );
162        }
163        let mut w = Self::rand(rng, q);
164        w.ds[0] = 1;
165        w
166    }
167
168    fn to_repr(&self) -> U8x16 {
169        // This function converts a [`WireMod3`] into its [`Block`] representation.
170        // The values stored in [`WireModQ`] are repacked depending on q
171        // into a 128b value as a [`Block`].
172        util::from_base_q(&self.ds, self.q).into()
173    }
174
175    fn color(&self) -> u16 {
176        let color = self.ds[0];
177        debug_assert!(color < self.q);
178        color
179    }
180
181    fn from_repr(inp: U8x16, q: u16) -> Self {
182        if q < 2 {
183            panic!(
184                "[WireModQ::from_block] Modulus must be at least 2. Got {}",
185                q
186            );
187        }
188        // This function converts a Block into its WireLabel representation
189        // by splitting the Block into several digits mod q that can each fit
190        // into 128b.
191        let ds = if util::is_power_of_2(q) {
192            // It's a power of 2, just split the digits.
193            let ndigits = util::digits_per_u128(q);
194            let width = 128 / ndigits;
195            let mask = (1 << width) - 1;
196            let x = u128::from(inp);
197            (0..ndigits)
198                .map(|i| ((x >> (width * i)) & mask) as u16)
199                .collect::<Vec<u16>>()
200        } else if q <= 23 {
201            _unrank(u128::from(inp), q)
202        } else {
203            // If all else fails, do unrank using naive division.
204            _unrank(u128::from(inp), q)
205        };
206        Self { q, ds }
207    }
208
209    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
210        if q < 2 {
211            panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
212        }
213        let ds = (0..util::digits_per_u128(q))
214            .map(|_| rng.r#gen::<u16>() % q)
215            .collect();
216        Self { q, ds }
217    }
218
219    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
220        if q < 2 {
221            panic!(
222                "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
223                q
224            );
225        }
226        Self::from_repr(hash, q)
227    }
228}
229
// `WireModQ` implements `ArithmeticWire` with an empty impl block: no trait
// items are defined or overridden here.
impl ArithmeticWire for WireModQ {}
231
#[cfg(test)]
mod tests {
    #[cfg(feature = "serde")]
    use super::WireModQ;
    #[cfg(feature = "serde")]
    use crate::WireLabel;
    #[cfg(feature = "serde")]
    use rand::Rng;
    #[cfg(feature = "serde")]
    use rand::thread_rng;

    /// Serializes `w` to JSON and tries to deserialize it back, so the result
    /// goes through the validating `UntrustedWireModQ` conversion.
    #[cfg(feature = "serde")]
    fn json_round_trip(w: &WireModQ) -> Result<WireModQ, serde_json::Error> {
        let serialized = serde_json::to_string(w).unwrap();
        serde_json::from_str(&serialized)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_mod_q() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let q: u16 = rng.gen_range(2..=u16::MAX);
            let w = WireModQ::rand(&mut rng, q);
            assert_eq!(json_round_trip(&w).unwrap(), w);
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_mod() {
        let mut rng = thread_rng();
        let q: u16 = rng.gen_range(2..=u16::MAX);
        let mut w = WireModQ::rand(&mut rng, q);

        // Manually mess with the modulus
        w.q = 1;
        assert!(json_round_trip(&w).is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_mod() {
        // Correct digit count for q = 2, but one digit is out of range. The
        // length is derived rather than hard-coded, so only the range check
        // can reject this input.
        let mut ds = vec![0; crate::util::digits_per_u128(2)];
        ds[5] = 5;
        let w = WireModQ { q: 2, ds };
        assert!(json_round_trip(&w).is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_count() {
        // Every digit is in range for q = 2, but one digit is missing.
        let ds = vec![1; crate::util::digits_per_u128(2) - 1];
        let w = WireModQ { q: 2, ds };
        assert!(json_round_trip(&w).is_err());
    }
}