1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Wire label exactly as produced by a serde deserializer, before any validation.
///
/// `WireModQ` deserializes through this type (see its `serde(try_from = ...)`
/// attribute). The `TryFrom` conversion rejects moduli below 2, a digit count
/// different from `util::digits_per_u128(q)`, and any digit `>= q`.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    // Claimed modulus; not yet checked.
    q: u16,
    // Claimed base-`q` digits; length and range not yet checked.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = crate::errors::WireDeserializationError;

    /// Validates a freshly deserialized wire and accepts it only if it is well formed.
    ///
    /// # Errors
    /// Returns `InvalidWireModQ` wrapping:
    /// - `BadModulus` when `q < 2`;
    /// - `InvalidDigitsLength` when the digit count differs from `digits_per_u128(q)`;
    /// - `DigitTooLarge` for the first digit that is `>= q`.
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        use crate::errors::ModQDeserializationError as ModQErr;

        let UntrustedWireModQ { q, ds } = wire;

        if q < 2 {
            return Err(Self::Error::InvalidWireModQ(ModQErr::BadModulus(q)));
        }

        let needed = crate::util::digits_per_u128(q);
        let got = ds.len();
        if got != needed {
            return Err(Self::Error::InvalidWireModQ(
                ModQErr::InvalidDigitsLength { got, needed },
            ));
        }

        // Report the first out-of-range digit, if any.
        if let Some(&digit) = ds.iter().find(|&&d| d >= q) {
            return Err(Self::Error::InvalidWireModQ(ModQErr::DigitTooLarge {
                digit,
                modulus: q,
            }));
        }

        Ok(WireModQ { q, ds })
    }
}
54
/// A wire label modulo `q`, stored as its base-`q` digits.
///
/// Note: the derived `Default` yields `q == 0` with no digits. That value does
/// not satisfy the `q >= 2` checks enforced by the constructors and by
/// deserialization, so treat it only as a placeholder.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// Deserialization is routed through `UntrustedWireModQ` so that every decoded
// wire is validated (modulus, digit count, digit range) before it is accepted.
#[cfg_attr(feature = "serde", serde(try_from = "UntrustedWireModQ"))]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WireModQ {
    /// The modulus; every digit in `ds` is expected to be `< q`.
    q: u16,
    /// Base-`q` digits of the label; `ds[0]` is the digit returned by `color()`.
    pub(crate) ds: Vec<u16>,
}
69
70impl HasModulus for WireModQ {
71 fn modulus(&self) -> u16 {
72 self.q
73 }
74}
75
76impl core::ops::Add for WireModQ {
77 type Output = Self;
78
79 fn add(self, rhs: Self) -> Self::Output {
80 assert_eq!(self.q, rhs.q);
81
82 let mut xs = self.ds.clone();
83 let ys = &rhs.ds;
84 let q = self.q;
85
86 debug_assert_eq!(xs.len(), ys.len());
87 xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
88 let (zp, overflow) = (*x + y).overflowing_sub(q);
89 *x = if overflow { *x + y } else { zp }
90 });
91 Self { ds: xs, q }
92 }
93}
94
95impl core::ops::AddAssign for WireModQ {
96 fn add_assign(&mut self, rhs: Self) {
97 assert_eq!(self.q, rhs.q);
98
99 let q = self.q;
100
101 debug_assert_eq!(self.ds.len(), rhs.ds.len());
102 self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
103 let (zp, overflow) = (*x + y).overflowing_sub(q);
104 *x = if overflow { *x + y } else { zp }
105 });
106 }
107}
108
109impl core::ops::Sub for WireModQ {
110 type Output = Self;
111
112 fn sub(self, rhs: Self) -> Self::Output {
113 self + -rhs
114 }
115}
116
117impl core::ops::SubAssign for WireModQ {
118 fn sub_assign(&mut self, rhs: Self) {
119 *self = self.clone() - rhs;
120 }
121}
122
123impl core::ops::Neg for WireModQ {
124 type Output = Self;
125
126 fn neg(self) -> Self::Output {
127 let q = self.q;
128 let mut ds = self.ds.clone();
129 ds.iter_mut().for_each(|d| {
130 if *d > 0 {
131 *d = q - *d;
132 } else {
133 *d = 0;
134 }
135 });
136 Self { q, ds }
137 }
138}
139
140impl core::ops::Mul<u16> for WireModQ {
141 type Output = Self;
142
143 fn mul(self, rhs: u16) -> Self::Output {
144 let q = self.q;
145 let mut ds = self.ds.clone();
146 ds.iter_mut()
147 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
148 Self { ds, q }
149 }
150}
151
152impl core::ops::MulAssign<u16> for WireModQ {
153 fn mul_assign(&mut self, rhs: u16) {
154 let q = self.q;
155 self.ds
156 .iter_mut()
157 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
158 }
159}
160
161impl WireLabel for WireModQ {
162 fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
163 if q < 2 {
164 panic!(
165 "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
166 q
167 );
168 }
169 let mut w = Self::rand(rng, q);
170 w.ds[0] = 1;
171 w
172 }
173
174 fn digits(&self) -> Vec<u16> {
175 self.ds.clone()
176 }
177
178 fn to_repr(&self) -> U8x16 {
179 util::from_base_q(&self.ds, self.q).into()
183 }
184
185 fn color(&self) -> u16 {
186 let color = self.ds[0];
187 debug_assert!(color < self.q);
188 color
189 }
190
191 fn from_repr(inp: U8x16, q: u16) -> Self {
192 if q < 2 {
193 panic!(
194 "[WireModQ::from_block] Modulus must be at least 2. Got {}",
195 q
196 );
197 }
198 let ds = if util::is_power_of_2(q) {
202 let ndigits = util::digits_per_u128(q);
204 let width = 128 / ndigits;
205 let mask = (1 << width) - 1;
206 let x = u128::from(inp);
207 (0..ndigits)
208 .map(|i| ((x >> (width * i)) & mask) as u16)
209 .collect::<Vec<u16>>()
210 } else if q <= 23 {
211 _unrank(u128::from(inp), q)
212 } else {
213 _unrank(u128::from(inp), q)
215 };
216 Self { q, ds }
217 }
218 fn zero(q: u16) -> Self {
221 if q < 2 {
222 panic!("[WireModQ::zero] Modulus must be at least 2. Got {}", q);
223 }
224 Self {
225 q,
226 ds: vec![0; util::digits_per_u128(q)],
227 }
228 }
229 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
230 if q < 2 {
231 panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
232 }
233 let ds = (0..util::digits_per_u128(q))
234 .map(|_| rng.r#gen::<u16>() % q)
235 .collect();
236 Self { q, ds }
237 }
238
239 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
240 if q < 2 {
241 panic!(
242 "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
243 q
244 );
245 }
246 Self::from_repr(hash, q)
247 }
248}
249
// Empty impl: `WireModQ` implements `ArithmeticWire` without overriding any
// items, so it relies only on the trait's provided defaults, if any.
// NOTE(review): presumably the trait's requirements are satisfied by the
// `core::ops` impls in this file; confirm against the trait definition.
impl ArithmeticWire for WireModQ {}
251
// Every test here exercises serde round-trips, so the whole module is gated
// on the feature instead of gating each import and test individually.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::WireModQ;
    use crate::WireLabel;
    use rand::{Rng, thread_rng};

    /// Draws a random modulus in `2..=u16::MAX` by rejecting 0 and 1.
    fn random_modulus<R: Rng>(rng: &mut R) -> u16 {
        loop {
            let q: u16 = rng.r#gen();
            if q >= 2 {
                return q;
            }
        }
    }

    /// Random valid wires survive a JSON round trip unchanged.
    #[test]
    fn test_serialize_good_mod_q() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let q = random_modulus(&mut rng);
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    /// A serialized wire with modulus 1 must be rejected on deserialization.
    #[test]
    fn test_serialize_bad_mod_q_mod() {
        let mut rng = thread_rng();
        let q = random_modulus(&mut rng);

        let mut w = WireModQ::rand(&mut rng, q);

        // Corrupt the modulus after construction to bypass the constructor's check.
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    /// A digit outside `0..q` (a 5 with q = 2) must be rejected.
    #[test]
    fn test_serialize_bad_mod_q_ds_mod() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }

    /// A digit vector of the wrong length for q = 2 must be rejected.
    #[test]
    fn test_serialize_bad_mod_q_ds_count() {
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }
}