1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Unvalidated mirror of [`WireModQ`] used as the deserialization target.
///
/// `WireModQ` is annotated with `serde(try_from = "UntrustedWireModQ")`, so
/// incoming data is first decoded into this struct and then checked by the
/// `TryFrom<UntrustedWireModQ>` conversion before a `WireModQ` is produced.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// Claimed modulus; the conversion rejects values below 2.
    q: u16,
    /// Claimed digits; the conversion checks both their count and their range.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = swanky_error::Error;

    /// Validates a freshly deserialized wire and promotes it to a [`WireModQ`].
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::OtherError` error when
    /// * the modulus is smaller than 2,
    /// * the number of digits differs from `util::digits_per_u128(q)`, or
    /// * any digit is not strictly smaller than the modulus (the message
    ///   reports the index of the first offending digit).
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        // The modulus must be checked first: the digit count below depends on it.
        swanky_error::ensure!(
            wire.q >= 2,
            swanky_error::ErrorKind::OtherError,
            "Modulus must be at least two",
        );

        let given_len = wire.ds.len();
        let expected_len = crate::util::digits_per_u128(wire.q);
        swanky_error::ensure!(
            given_len == expected_len,
            swanky_error::ErrorKind::OtherError,
            "Invalid number of digits. Expected: {expected_len}. Got: {given_len}"
        );

        // Stop at the first digit that is out of range for this modulus.
        for (i, &digit) in wire.ds.iter().enumerate() {
            if digit >= wire.q {
                swanky_error::bail!(
                    swanky_error::ErrorKind::OtherError,
                    "Digit {i} is greater than the modulus",
                );
            }
        }

        let UntrustedWireModQ { q, ds } = wire;
        Ok(WireModQ { q, ds })
    }
}
48
/// A wire label represented as a vector of digits together with their modulus.
///
/// With the `serde` feature enabled the type is serializable, and
/// deserialization is routed through `UntrustedWireModQ`, whose `TryFrom`
/// conversion validates the modulus and the digits.
///
/// Note that the derived `Default` yields `q == 0` with no digits, which is
/// not a value the constructors in this file (`rand`, `from_repr`, ...) would
/// produce, since they all require `q >= 2`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "UntrustedWireModQ"))]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WireModQ {
    /// The modulus, as returned by `HasModulus::modulus`.
    q: u16,
    /// The digits of the label; `ds[0]` is the value returned by `color`.
    pub(crate) ds: Vec<u16>,
}
63
impl HasModulus for WireModQ {
    /// Returns the modulus `q` stored in this wire.
    fn modulus(&self) -> u16 {
        self.q
    }
}
69
70impl core::ops::Add for WireModQ {
71 type Output = Self;
72
73 fn add(self, rhs: Self) -> Self::Output {
74 assert_eq!(self.q, rhs.q);
75
76 let mut xs = self.ds.clone();
77 let ys = &rhs.ds;
78 let q = self.q;
79
80 debug_assert_eq!(xs.len(), ys.len());
81 xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
82 let (zp, overflow) = (*x + y).overflowing_sub(q);
83 *x = if overflow { *x + y } else { zp }
84 });
85 Self { ds: xs, q }
86 }
87}
88
89impl core::ops::AddAssign for WireModQ {
90 fn add_assign(&mut self, rhs: Self) {
91 assert_eq!(self.q, rhs.q);
92
93 let q = self.q;
94
95 debug_assert_eq!(self.ds.len(), rhs.ds.len());
96 self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
97 let (zp, overflow) = (*x + y).overflowing_sub(q);
98 *x = if overflow { *x + y } else { zp }
99 });
100 }
101}
102
103impl core::ops::Sub for WireModQ {
104 type Output = Self;
105
106 fn sub(self, rhs: Self) -> Self::Output {
107 self + -rhs
108 }
109}
110
111impl core::ops::SubAssign for WireModQ {
112 fn sub_assign(&mut self, rhs: Self) {
113 *self = self.clone() - rhs;
114 }
115}
116
117impl core::ops::Neg for WireModQ {
118 type Output = Self;
119
120 fn neg(self) -> Self::Output {
121 let q = self.q;
122 let mut ds = self.ds.clone();
123 ds.iter_mut().for_each(|d| {
124 if *d > 0 {
125 *d = q - *d;
126 } else {
127 *d = 0;
128 }
129 });
130 Self { q, ds }
131 }
132}
133
134impl core::ops::Mul<u16> for WireModQ {
135 type Output = Self;
136
137 fn mul(self, rhs: u16) -> Self::Output {
138 let q = self.q;
139 let mut ds = self.ds.clone();
140 ds.iter_mut()
141 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
142 Self { ds, q }
143 }
144}
145
146impl core::ops::MulAssign<u16> for WireModQ {
147 fn mul_assign(&mut self, rhs: u16) {
148 let q = self.q;
149 self.ds
150 .iter_mut()
151 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
152 }
153}
154
155impl WireLabel for WireModQ {
156 fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
157 if q < 2 {
158 panic!(
159 "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
160 q
161 );
162 }
163 let mut w = Self::rand(rng, q);
164 w.ds[0] = 1;
165 w
166 }
167
168 fn digits(&self) -> Vec<u16> {
169 self.ds.clone()
170 }
171
172 fn to_repr(&self) -> U8x16 {
173 util::from_base_q(&self.ds, self.q).into()
177 }
178
179 fn color(&self) -> u16 {
180 let color = self.ds[0];
181 debug_assert!(color < self.q);
182 color
183 }
184
185 fn from_repr(inp: U8x16, q: u16) -> Self {
186 if q < 2 {
187 panic!(
188 "[WireModQ::from_block] Modulus must be at least 2. Got {}",
189 q
190 );
191 }
192 let ds = if util::is_power_of_2(q) {
196 let ndigits = util::digits_per_u128(q);
198 let width = 128 / ndigits;
199 let mask = (1 << width) - 1;
200 let x = u128::from(inp);
201 (0..ndigits)
202 .map(|i| ((x >> (width * i)) & mask) as u16)
203 .collect::<Vec<u16>>()
204 } else if q <= 23 {
205 _unrank(u128::from(inp), q)
206 } else {
207 _unrank(u128::from(inp), q)
209 };
210 Self { q, ds }
211 }
212
213 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
214 if q < 2 {
215 panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
216 }
217 let ds = (0..util::digits_per_u128(q))
218 .map(|_| rng.r#gen::<u16>() % q)
219 .collect();
220 Self { q, ds }
221 }
222
223 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
224 if q < 2 {
225 panic!(
226 "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
227 q
228 );
229 }
230 Self::from_repr(hash, q)
231 }
232}
233
// Opts `WireModQ` into `ArithmeticWire`; the impl body is empty, so nothing is
// defined or overridden here.
impl ArithmeticWire for WireModQ {}
235
// Serialization round-trip and rejection tests for `WireModQ`. Every import and
// test is gated on the `serde` feature, so the module is empty without it.
#[cfg(test)]
mod tests {
    #[cfg(feature = "serde")]
    use super::WireModQ;
    #[cfg(feature = "serde")]
    use crate::WireLabel;
    #[cfg(feature = "serde")]
    use rand::Rng;
    #[cfg(feature = "serde")]
    use rand::thread_rng;

    // A randomly generated wire must survive a JSON round trip unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_modQ() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            // Re-sample until the modulus is valid (`WireModQ::rand` panics for q < 2).
            let mut q: u16 = rng.r#gen();
            while q < 2 {
                q = rng.r#gen();
            }
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }
    // A serialized wire whose modulus is below 2 must be rejected on deserialization.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modQ_mod() {
        let mut rng = thread_rng();
        let mut q: u16 = rng.r#gen();
        while q < 2 {
            q = rng.r#gen();
        }

        let mut w = WireModQ::rand(&mut rng, q);

        // Corrupt the private modulus directly; serialization itself performs no validation.
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }
    // A payload containing a digit (the `5`) that is not below the modulus 2 must be rejected.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modQ_ds_mod() {
        let serialized: String = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}".to_string();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    // Same digit list as the previous test with the `5` removed, i.e. one entry
    // shorter; presumably that no longer matches `digits_per_u128(2)`, so the
    // payload must be rejected for having the wrong number of digits.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_modQ_ds_count() {
        let serialized: String = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}".to_string();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }
}