1use crate::{ArithmeticWire, HasModulus, WireLabel, util, wire::_unrank};
2use rand::{CryptoRng, Rng, RngCore};
3use vectoreyes::U8x16;
4
/// Raw, unvalidated form of a [`WireModQ`] as it arrives from a deserializer.
///
/// `WireModQ` is deserialized via `serde(try_from = "UntrustedWireModQ")`, so
/// every decoded value passes through this type and its `TryFrom` conversion
/// before it can become a `WireModQ`.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// Claimed modulus; not yet checked to be at least two.
    q: u16,
    /// Claimed digits; neither their count nor their range has been checked.
    ds: Vec<u16>,
}
16
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = swanky_error::Error;

    /// Validates a freshly deserialized wire before admitting it as a [`WireModQ`].
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the modulus `q` is smaller than two;
    /// - the number of digits differs from `util::digits_per_u128(q)`;
    /// - any digit is not reduced modulo `q` (i.e. is `>= q`).
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        swanky_error::ensure!(
            wire.q >= 2,
            swanky_error::ErrorKind::OtherError,
            "Modulus must be at least two",
        );

        // The modulus is validated first so `digits_per_u128` is only ever
        // asked about a modulus of at least two.
        let expected_len = crate::util::digits_per_u128(wire.q);
        let given_len = wire.ds.len();
        swanky_error::ensure!(
            given_len == expected_len,
            swanky_error::ErrorKind::OtherError,
            "Invalid number of digits. Expected: {expected_len}. Got: {given_len}"
        );

        // A digit equal to the modulus is just as unreduced as one above it,
        // so the message must not claim "greater than".
        if let Some((i, &digit)) = wire.ds.iter().enumerate().find(|&(_, &x)| x >= wire.q) {
            let q = wire.q;
            swanky_error::bail!(
                swanky_error::ErrorKind::OtherError,
                "Digit {i} ({digit}) is not less than the modulus {q}",
            );
        }

        Ok(WireModQ {
            q: wire.q,
            ds: wire.ds,
        })
    }
}
48
/// A wire label stored as a vector of digits modulo `q`.
///
/// When the `serde` feature is enabled, deserialization is routed through
/// `UntrustedWireModQ`. That conversion rejects any value where `q < 2`, where
/// the digit count differs from `util::digits_per_u128(q)`, or where a digit is
/// `>= q`.
///
/// NOTE(review): the derived `Default` produces `q == 0` with no digits, which
/// does not satisfy the invariants above. Treat it as a placeholder only.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "UntrustedWireModQ"))]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WireModQ {
    /// Modulus that every digit is reduced by.
    q: u16,
    /// The digits; `ds[0]` is what `WireLabel::color` reports.
    pub(crate) ds: Vec<u16>,
}
63
64impl HasModulus for WireModQ {
65 fn modulus(&self) -> u16 {
66 self.q
67 }
68}
69
70impl core::ops::Add for WireModQ {
71 type Output = Self;
72
73 fn add(self, rhs: Self) -> Self::Output {
74 assert_eq!(self.q, rhs.q);
75
76 let mut xs = self.ds.clone();
77 let ys = &rhs.ds;
78 let q = self.q;
79
80 debug_assert_eq!(xs.len(), ys.len());
81 xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
82 let (zp, overflow) = (*x + y).overflowing_sub(q);
83 *x = if overflow { *x + y } else { zp }
84 });
85 Self { ds: xs, q }
86 }
87}
88
89impl core::ops::AddAssign for WireModQ {
90 fn add_assign(&mut self, rhs: Self) {
91 assert_eq!(self.q, rhs.q);
92
93 let q = self.q;
94
95 debug_assert_eq!(self.ds.len(), rhs.ds.len());
96 self.ds.iter_mut().zip(rhs.ds.iter()).for_each(|(x, &y)| {
97 let (zp, overflow) = (*x + y).overflowing_sub(q);
98 *x = if overflow { *x + y } else { zp }
99 });
100 }
101}
102
103impl core::ops::Sub for WireModQ {
104 type Output = Self;
105
106 fn sub(self, rhs: Self) -> Self::Output {
107 self + -rhs
108 }
109}
110
111impl core::ops::SubAssign for WireModQ {
112 fn sub_assign(&mut self, rhs: Self) {
113 *self = self.clone() - rhs;
114 }
115}
116
117impl core::ops::Neg for WireModQ {
118 type Output = Self;
119
120 fn neg(self) -> Self::Output {
121 let q = self.q;
122 let mut ds = self.ds.clone();
123 ds.iter_mut().for_each(|d| {
124 if *d > 0 {
125 *d = q - *d;
126 } else {
127 *d = 0;
128 }
129 });
130 Self { q, ds }
131 }
132}
133
134impl core::ops::Mul<u16> for WireModQ {
135 type Output = Self;
136
137 fn mul(self, rhs: u16) -> Self::Output {
138 let q = self.q;
139 let mut ds = self.ds.clone();
140 ds.iter_mut()
141 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
142 Self { ds, q }
143 }
144}
145
146impl core::ops::MulAssign<u16> for WireModQ {
147 fn mul_assign(&mut self, rhs: u16) {
148 let q = self.q;
149 self.ds
150 .iter_mut()
151 .for_each(|d| *d = (*d as u32 * rhs as u32 % q as u32) as u16);
152 }
153}
154
155impl WireLabel for WireModQ {
156 fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
157 if q < 2 {
158 panic!(
159 "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
160 q
161 );
162 }
163 let mut w = Self::rand(rng, q);
164 w.ds[0] = 1;
165 w
166 }
167
168 fn to_repr(&self) -> U8x16 {
169 util::from_base_q(&self.ds, self.q).into()
173 }
174
175 fn color(&self) -> u16 {
176 let color = self.ds[0];
177 debug_assert!(color < self.q);
178 color
179 }
180
181 fn from_repr(inp: U8x16, q: u16) -> Self {
182 if q < 2 {
183 panic!(
184 "[WireModQ::from_block] Modulus must be at least 2. Got {}",
185 q
186 );
187 }
188 let ds = if util::is_power_of_2(q) {
192 let ndigits = util::digits_per_u128(q);
194 let width = 128 / ndigits;
195 let mask = (1 << width) - 1;
196 let x = u128::from(inp);
197 (0..ndigits)
198 .map(|i| ((x >> (width * i)) & mask) as u16)
199 .collect::<Vec<u16>>()
200 } else if q <= 23 {
201 _unrank(u128::from(inp), q)
202 } else {
203 _unrank(u128::from(inp), q)
205 };
206 Self { q, ds }
207 }
208
209 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
210 if q < 2 {
211 panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
212 }
213 let ds = (0..util::digits_per_u128(q))
214 .map(|_| rng.r#gen::<u16>() % q)
215 .collect();
216 Self { q, ds }
217 }
218
219 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
220 if q < 2 {
221 panic!(
222 "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
223 q
224 );
225 }
226 Self::from_repr(hash, q)
227 }
228}
229
230impl ArithmeticWire for WireModQ {}
231
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::WireModQ;
    use crate::WireLabel;
    use rand::{Rng, thread_rng};

    /// Draws a modulus uniformly from `[2, u16::MAX]`, the range `WireModQ` accepts.
    fn random_modulus<R: Rng>(rng: &mut R) -> u16 {
        rng.gen_range(2..=u16::MAX)
    }

    #[test]
    fn test_serialize_good_mod_q() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let q = random_modulus(&mut rng);
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    #[test]
    fn test_serialize_bad_mod_q_mod() {
        let mut rng = thread_rng();
        let q = random_modulus(&mut rng);
        let mut w = WireModQ::rand(&mut rng, q);

        // A modulus below two must be rejected on deserialization.
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    #[test]
    fn test_serialize_bad_mod_q_digit_equal_to_modulus() {
        let mut rng = thread_rng();
        let q = random_modulus(&mut rng);
        let mut w = WireModQ::rand(&mut rng, q);

        // Boundary case: a digit equal to q is not reduced and must be rejected.
        w.ds[0] = q;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    #[test]
    fn test_serialize_bad_mod_q_ds_mod() {
        // 128 digits for q = 2, but digit 5 holds the out-of-range value 5.
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }

    #[test]
    fn test_serialize_bad_mod_q_ds_count() {
        // Only 127 digits for q = 2, so the digit count check must fail.
        let serialized = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}";

        let deserialized: Result<WireModQ, _> = serde_json::from_str(serialized);
        assert!(deserialized.is_err());
    }
}