1use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_aes_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12 U8x16,
13 array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27 ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29 let batch = wires.array_map(|x| x.to_repr());
30 TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait with no methods of its own; implementors only need to be [`Clone`].
///
/// NOTE(review): presumably tags wire types usable with arithmetic (non-binary) gadgets —
/// confirm against the call sites.
pub trait ArithmeticWire
where
    Self: Clone,
{
}
36
37pub trait WireLabel:
42 Clone
43 + HasModulus
44 + core::ops::Add<Output = Self>
45 + core::ops::AddAssign
46 + core::ops::Sub<Output = Self>
47 + core::ops::SubAssign
48 + core::ops::Neg<Output = Self>
49 + core::ops::Mul<u16, Output = Self>
50 + core::ops::MulAssign<u16>
51{
52 fn digits(&self) -> Vec<u16>;
54
55 fn to_repr(&self) -> U8x16;
57
58 fn color(&self) -> u16;
60
61 fn from_repr(inp: U8x16, q: u16) -> Self;
68
69 fn zero(q: u16) -> Self;
78
79 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
85
86 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
92
93 fn hash_to_mod(hash: U8x16, q: u16) -> Self;
102
103 fn hashback(&self, tweak: u128, q: u16) -> Self {
114 let hash = self.hash(tweak);
115 Self::hash_to_mod(hash, q)
116 }
117
118 #[inline(never)]
120 fn hash(&self, tweak: u128) -> U8x16 {
121 TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
122 }
123}
124
/// A wire label of any modulus, wrapping the specialized representation for that modulus.
///
/// The constructors in the `WireLabel` impl below pick the variant from `q`: `2` gives
/// `Mod2`, `3` gives `Mod3`, and every other modulus gives `ModN`.
///
/// NOTE(review): keep the variant order stable — derived serde impls may encode variants by
/// index in non-self-describing formats.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AllWire {
    /// Label with modulus 2.
    Mod2(WireMod2),
    /// Label with modulus 3.
    Mod3(WireMod3),
    /// Label with any other modulus.
    ModN(WireModQ),
}
136
137impl HasModulus for AllWire {
138 fn modulus(&self) -> u16 {
139 match &self {
140 AllWire::Mod2(x) => x.modulus(),
141 AllWire::Mod3(x) => x.modulus(),
142 AllWire::ModN(x) => x.modulus(),
143 }
144 }
145}
146
147impl core::ops::Add for AllWire {
148 type Output = Self;
149
150 fn add(self, rhs: Self) -> Self::Output {
151 let (p, q) = (self.modulus(), rhs.modulus());
152 match (self, rhs) {
153 (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
154 (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
155 (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
156 _ => panic!("unequal moduli: {p} != {q}"),
157 }
158 }
159}
160
161impl core::ops::AddAssign for AllWire {
162 fn add_assign(&mut self, rhs: Self) {
163 let (p, q) = (self.modulus(), rhs.modulus());
164 match (self, rhs) {
165 (Self::Mod2(x), Self::Mod2(y)) => *x += y,
166 (Self::Mod3(x), Self::Mod3(y)) => *x += y,
167 (Self::ModN(x), Self::ModN(y)) => *x += y,
168 _ => panic!("unequal moduli: {p} != {q}"),
169 }
170 }
171}
172
173impl core::ops::Sub for AllWire {
174 type Output = Self;
175
176 fn sub(self, rhs: Self) -> Self::Output {
177 self + -rhs
178 }
179}
180
181impl core::ops::SubAssign for AllWire {
182 fn sub_assign(&mut self, rhs: Self) {
183 *self = self.clone() - rhs;
184 }
185}
186
187impl core::ops::Neg for AllWire {
188 type Output = Self;
189
190 fn neg(self) -> Self::Output {
191 match self {
192 Self::Mod2(x) => Self::Mod2(-x),
193 Self::Mod3(x) => Self::Mod3(-x),
194 Self::ModN(x) => Self::ModN(-x),
195 }
196 }
197}
198
199impl core::ops::Mul<u16> for AllWire {
200 type Output = Self;
201
202 fn mul(self, rhs: u16) -> Self::Output {
203 match self {
204 Self::Mod2(x) => Self::Mod2(x * rhs),
205 Self::Mod3(x) => Self::Mod3(x * rhs),
206 Self::ModN(x) => Self::ModN(x * rhs),
207 }
208 }
209}
210
211impl core::ops::MulAssign<u16> for AllWire {
212 fn mul_assign(&mut self, rhs: u16) {
213 match self {
214 Self::Mod2(x) => {
215 *x *= rhs;
216 }
217 Self::Mod3(x) => {
218 *x *= rhs;
219 }
220 Self::ModN(x) => {
221 *x *= rhs;
222 }
223 };
224 }
225}
226
227impl WireLabel for AllWire {
228 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
229 match q {
230 2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
231 3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
232 _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
233 }
234 }
235
236 fn digits(&self) -> Vec<u16> {
237 match &self {
238 AllWire::Mod2(x) => x.digits(),
239 AllWire::Mod3(x) => x.digits(),
240 AllWire::ModN(x) => x.digits(),
241 }
242 }
243
244 fn to_repr(&self) -> U8x16 {
245 match &self {
246 AllWire::Mod2(x) => x.to_repr(),
247 AllWire::Mod3(x) => x.to_repr(),
248 AllWire::ModN(x) => x.to_repr(),
249 }
250 }
251 fn color(&self) -> u16 {
252 match &self {
253 AllWire::Mod2(x) => x.color(),
254 AllWire::Mod3(x) => x.color(),
255 AllWire::ModN(x) => x.color(),
256 }
257 }
258 fn from_repr(inp: U8x16, q: u16) -> Self {
259 match q {
260 2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
261 3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
262 _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
263 }
264 }
265
266 fn zero(q: u16) -> Self {
267 match q {
268 2 => AllWire::Mod2(WireMod2::zero(q)),
269 3 => AllWire::Mod3(WireMod3::zero(q)),
270 _ => AllWire::ModN(WireModQ::zero(q)),
271 }
272 }
273
274 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
275 match q {
276 2 => AllWire::Mod2(WireMod2::rand(rng, q)),
277 3 => AllWire::Mod3(WireMod3::rand(rng, q)),
278 _ => AllWire::ModN(WireModQ::rand(rng, q)),
279 }
280 }
281
282 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
283 if q == 3 {
284 AllWire::Mod3(WireMod3::encode_block_mod3(hash))
285 } else {
286 Self::from_repr(hash, q)
287 }
288 }
289}
/// Decomposes `inp` into `util::digits_per_u128(q)` digits using the precomputed
/// `npaths_tab` table for modulus `q`.
///
/// NOTE(review): assumes `npaths_tab::lookup(q)[i] == q^i` — TODO confirm in `npaths_tab`.
/// Under that assumption `ds[i]` is the coefficient of `q^i`, so the result is the base-`q`
/// expansion with the least-significant digit first.
///
/// Nothing in the code shown here calls this function; the leading underscore keeps the
/// dead-code lint quiet.
fn _unrank(inp: u128, q: u16) -> Vec<u16> {
    let mut x = inp;
    let ndigits = util::digits_per_u128(q);
    let npaths_tab = npaths_tab::lookup(q);
    // Reduce `x` modulo `npaths_tab[ndigits - 1] * q` so it fits in `ndigits` digits.
    // NOTE(review): if this product is `q^ndigits` and that reaches 2^128, the
    // multiplication overflows (panics in debug builds) — confirm against `digits_per_u128`.
    x %= npaths_tab[ndigits - 1] * q as u128;

    let mut ds = vec![0; ndigits];
    // Walk positions from `ndigits - 1` down to 0, peeling one digit off `x` at each step.
    for i in (0..ndigits).rev() {
        let npaths = npaths_tab[i];

        if q <= 23 {
            // Small moduli: linear search. Accumulate `npaths` until the running total
            // exceeds `x`; the number of whole steps taken before that is the digit `j`.
            let mut acc = 0;
            for j in 0..q {
                acc += npaths;
                if acc > x {
                    // `acc - npaths == j * npaths`, the part of `x` accounted for by digit `j`.
                    x -= acc - npaths;
                    ds[i] = j;
                    break;
                }
            }
        } else {
            // Larger moduli: plain division yields the digit directly.
            let d = x / npaths;
            ds[i] = d as u16;
            x -= d * npaths;
        }
    }
    ds
}
341
// `AllWire` already satisfies the `Clone` supertrait through its derive, so the marker impl
// needs no items.
impl ArithmeticWire for AllWire {}
343
// Unit tests for `AllWire`: packing round-trips, digit decomposition, hashing, and the
// arithmetic operators. Several assertions are probabilistic over random labels.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    // `from_repr` must invert `to_repr` for every modulus in 2..256.
    #[test]
    fn packing() {
        let rng = &mut thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let w = AllWire::rand(rng, q);
                assert_eq!(w, AllWire::from_repr(w.to_repr(), q));
            }
        }
    }

    // Digits of a label built from a raw u128 must match the reference base-q conversion
    // `util::as_base_q_u128` (moduli 5..=114).
    #[test]
    fn base_conversion_lookup_method() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let w = AllWire::from_repr(U8x16::from(x), q);
            let should_be = util::as_base_q_u128(x, q);
            assert_eq!(w.digits(), should_be, "x={} q={}", x, q);
        }
    }

    // `hashback` must return a label different from its input and not all-zero
    // (moduli 2..=111). These checks are probabilistic: a random hash could in principle hit
    // the excluded values.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            assert!(x != y);
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    // Double negation is the identity. For q != 2 a random label should also differ from its
    // negation; mod 2 every label is its own negation (see the q == 2 branch of
    // `arithmetic`), so that check is skipped.
    #[test]
    fn negation() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(rng, q);
            let xneg = -x.clone();
            if q != 2 {
                assert!(x != xneg);
            }
            let y = -xneg;
            assert_eq!(x, y);
        }
    }

    // `zero(q)` must decompose into all-zero digits (moduli 3..=112).
    #[test]
    fn zero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 3 + (rng.gen_u16() % 110);
            let z = AllWire::zero(q);
            let ds = z.digits();
            assert_eq!(ds, vec![0; ds.len()], "q={}", q);
        }
    }

    // x - x == 0.
    #[test]
    fn subzero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let z = AllWire::zero(q);
            assert_eq!(x.clone() - x, z);
        }
    }

    // x + 0 == x.
    #[test]
    fn pluszero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() + AllWire::zero(q), x);
        }
    }

    // Algebraic identities for the operators:
    // - scaling by 0 or by q yields zero;
    // - repeated addition matches scalar multiplication;
    // - x + (-x) == 0, and subtraction equals adding the negation (mod 2, + equals -);
    // - the compound-assignment operators agree with their binary counterparts.
    #[test]
    // `x * 0` below is intentional, so silence clippy's `erasing_op` lint.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() * 0, AllWire::zero(q));
            assert_eq!(x.clone() * q, AllWire::zero(q));
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), AllWire::zero(q), "q={}", q);
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }
            let mut w = x.clone();
            let z = w.clone() + y.clone();
            w += y;
            assert_eq!(w, z);

            w = x.clone();
            w *= 2;
            assert_eq!(x.clone() + x.clone(), w);

            w = x.clone();
            w = -w;
            assert_eq!(-x, w);
        }
    }

    // A label of modulus q must have exactly `util::digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.digits().len(), util::digits_per_u128(q));
        }
    }

    // Hashing labels on other threads must produce the same digests as hashing them here.
    // NOTE(review): this spawns one OS thread per label (n = 1000); a small pool or
    // `std::thread::scope` over chunks would exercise the same property more cheaply.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let ws = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let mut handles = Vec::new();
        for w in ws.iter() {
            let w_ = w.clone();
            let h = std::thread::spawn(move || w_.hash(0u128));
            handles.push(h);
        }
        let hashes = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();

        let should_be = ws.iter().map(|w| w.hash(0u128)).collect_vec();

        assert_eq!(hashes, should_be);
    }

    // JSON round-trip through serde for moduli 2..16; only compiled with the `serde` feature.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let w = AllWire::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: AllWire = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }
}