1use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12 U8x16,
13 array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27 ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29 let batch = wires.array_map(|x| x.to_repr());
30 TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait for wire types usable in arithmetic computations.
///
/// It declares no methods; the only requirement is that the type is `Clone`.
pub trait ArithmeticWire
where
    Self: Clone,
{
}
36
37pub trait WireLabel:
42 Clone
43 + HasModulus
44 + core::ops::Add<Output = Self>
45 + core::ops::AddAssign
46 + core::ops::Sub<Output = Self>
47 + core::ops::SubAssign
48 + core::ops::Neg<Output = Self>
49 + core::ops::Mul<u16, Output = Self>
50 + core::ops::MulAssign<u16>
51{
52 fn digits(&self) -> Vec<u16>;
54
55 fn to_repr(&self) -> U8x16;
57
58 fn color(&self) -> u16;
60
61 fn from_repr(inp: U8x16, q: u16) -> Self;
68
69 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
75
76 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
82
83 fn hash_to_mod(hash: U8x16, q: u16) -> Self;
92
93 fn hashback(&self, tweak: u128, q: u16) -> Self {
104 let hash = self.hash(tweak);
105 Self::hash_to_mod(hash, q)
106 }
107
108 fn hash(&self, tweak: u128) -> U8x16 {
110 TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
111 }
112
113 fn constant<RNG: CryptoRng + RngCore>(
116 x: u16,
117 q: u16,
118 delta: &Self,
119 rng: &mut RNG,
120 ) -> (Self, Self) {
121 let zero = Self::rand(rng, q);
122 let wire = zero.clone() + delta.clone() * x;
123 (zero, wire)
124 }
125}
126
127#[derive(Debug, Clone, PartialEq)]
128#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
129pub enum AllWire {
131 Mod2(WireMod2),
133 Mod3(WireMod3),
135 ModN(WireModQ),
137}
138
139impl HasModulus for AllWire {
140 fn modulus(&self) -> u16 {
141 match &self {
142 AllWire::Mod2(x) => x.modulus(),
143 AllWire::Mod3(x) => x.modulus(),
144 AllWire::ModN(x) => x.modulus(),
145 }
146 }
147}
148
149impl core::ops::Add for AllWire {
150 type Output = Self;
151
152 fn add(self, rhs: Self) -> Self::Output {
153 let (p, q) = (self.modulus(), rhs.modulus());
154 match (self, rhs) {
155 (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
156 (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
157 (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
158 _ => panic!("unequal moduli: {p} != {q}"),
159 }
160 }
161}
162
163impl core::ops::AddAssign for AllWire {
164 fn add_assign(&mut self, rhs: Self) {
165 let (p, q) = (self.modulus(), rhs.modulus());
166 match (self, rhs) {
167 (Self::Mod2(x), Self::Mod2(y)) => *x += y,
168 (Self::Mod3(x), Self::Mod3(y)) => *x += y,
169 (Self::ModN(x), Self::ModN(y)) => *x += y,
170 _ => panic!("unequal moduli: {p} != {q}"),
171 }
172 }
173}
174
175impl core::ops::Sub for AllWire {
176 type Output = Self;
177
178 fn sub(self, rhs: Self) -> Self::Output {
179 self + -rhs
180 }
181}
182
183impl core::ops::SubAssign for AllWire {
184 fn sub_assign(&mut self, rhs: Self) {
185 *self = self.clone() - rhs;
186 }
187}
188
189impl core::ops::Neg for AllWire {
190 type Output = Self;
191
192 fn neg(self) -> Self::Output {
193 match self {
194 Self::Mod2(x) => Self::Mod2(-x),
195 Self::Mod3(x) => Self::Mod3(-x),
196 Self::ModN(x) => Self::ModN(-x),
197 }
198 }
199}
200
201impl core::ops::Mul<u16> for AllWire {
202 type Output = Self;
203
204 fn mul(self, rhs: u16) -> Self::Output {
205 match self {
206 Self::Mod2(x) => Self::Mod2(x * rhs),
207 Self::Mod3(x) => Self::Mod3(x * rhs),
208 Self::ModN(x) => Self::ModN(x * rhs),
209 }
210 }
211}
212
213impl core::ops::MulAssign<u16> for AllWire {
214 fn mul_assign(&mut self, rhs: u16) {
215 match self {
216 Self::Mod2(x) => {
217 *x *= rhs;
218 }
219 Self::Mod3(x) => {
220 *x *= rhs;
221 }
222 Self::ModN(x) => {
223 *x *= rhs;
224 }
225 };
226 }
227}
228
229impl WireLabel for AllWire {
230 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
231 match q {
232 2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
233 3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
234 _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
235 }
236 }
237
238 fn digits(&self) -> Vec<u16> {
239 match &self {
240 AllWire::Mod2(x) => x.digits(),
241 AllWire::Mod3(x) => x.digits(),
242 AllWire::ModN(x) => x.digits(),
243 }
244 }
245
246 fn to_repr(&self) -> U8x16 {
247 match &self {
248 AllWire::Mod2(x) => x.to_repr(),
249 AllWire::Mod3(x) => x.to_repr(),
250 AllWire::ModN(x) => x.to_repr(),
251 }
252 }
253 fn color(&self) -> u16 {
254 match &self {
255 AllWire::Mod2(x) => x.color(),
256 AllWire::Mod3(x) => x.color(),
257 AllWire::ModN(x) => x.color(),
258 }
259 }
260 fn from_repr(inp: U8x16, q: u16) -> Self {
261 match q {
262 2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
263 3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
264 _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
265 }
266 }
267
268 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
269 match q {
270 2 => AllWire::Mod2(WireMod2::rand(rng, q)),
271 3 => AllWire::Mod3(WireMod3::rand(rng, q)),
272 _ => AllWire::ModN(WireModQ::rand(rng, q)),
273 }
274 }
275
276 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
277 if q == 3 {
278 AllWire::Mod3(WireMod3::encode_block_mod3(hash))
279 } else {
280 Self::from_repr(hash, q)
281 }
282 }
283}
284fn _unrank(inp: u128, q: u16) -> Vec<u16> {
285 let mut x = inp;
286 let ndigits = util::digits_per_u128(q);
287 let npaths_tab = npaths_tab::lookup(q);
288 x %= npaths_tab[ndigits - 1] * q as u128;
289
290 let mut ds = vec![0; ndigits];
291 for i in (0..ndigits).rev() {
292 let npaths = npaths_tab[i];
293
294 if q <= 23 {
295 let mut acc = 0;
297 for j in 0..q {
298 acc += npaths;
299 if acc > x {
300 x -= acc - npaths;
301 ds[i] = j;
302 break;
303 }
304 }
305 } else {
306 let d = x / npaths;
308 ds[i] = d as u16;
309 x -= d * npaths;
310 }
311 }
333 ds
334}
335
// `AllWire` covers every modulus (the `_` arms of its constructors fall back to
// `WireModQ`), so it can be used wherever an `ArithmeticWire` is required.
impl ArithmeticWire for AllWire {}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    /// `to_repr` followed by `from_repr` must reproduce the original label.
    #[test]
    fn packing() {
        let mut rng = thread_rng();
        for q in 2..256 {
            (0..1000).for_each(|_| {
                let wire = AllWire::rand(&mut rng, q);
                let unpacked = AllWire::from_repr(wire.to_repr(), q);
                assert_eq!(wire, unpacked);
            });
        }
    }

    /// Digits from `from_repr` must match the reference base-q conversion.
    #[test]
    fn base_conversion_lookup_method() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let wire = AllWire::from_repr(U8x16::from(x), q);
            let expected = util::as_base_q_u128(x, q);
            assert_eq!(wire.digits(), expected, "x={} q={}", x, q);
        }
    }

    /// `hashback` must change the label and produce a non-zero result.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let label = AllWire::rand(&mut rng, q);
            let hashed = label.hashback(1u128, q);
            assert_ne!(label, hashed);
            let nonzero = match hashed {
                AllWire::Mod2(WireMod2 { val }) => u128::from(val) > 0,
                AllWire::Mod3(WireMod3 { lsb, msb }) => lsb > 0 && msb > 0,
                AllWire::ModN(WireModQ { ds, .. }) => ds.iter().any(|&d| d != 0),
            };
            assert!(nonzero);
        }
    }

    /// Negation is an involution, and is non-trivial except mod 2.
    #[test]
    fn negation() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let negated = -x.clone();
            if q != 2 {
                assert_ne!(x, negated);
            }
            assert_eq!(x, -negated);
        }
    }

    /// Operators must agree with each other and with modular arithmetic.
    #[test]
    // `x * 0` is deliberate: it checks that scaling by zero gives the zero label.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            let zero = x.clone() - x.clone();

            assert_eq!(x.clone() * 0, zero);
            assert_eq!(x.clone() * q, zero);
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), zero);
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }

            // `+=` agrees with `+`.
            let mut acc = x.clone();
            let sum = acc.clone() + y.clone();
            acc += y;
            assert_eq!(acc, sum);

            // `*=` agrees with repeated addition.
            let mut doubled = x.clone();
            doubled *= 2;
            assert_eq!(x.clone() + x.clone(), doubled);

            // Negating a copy matches negating the original.
            let negated = -x.clone();
            assert_eq!(-x, negated);
        }
    }

    /// Every label must expose exactly `digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let ndigits = AllWire::rand(&mut rng, q).digits().len();
            assert_eq!(ndigits, util::digits_per_u128(q));
        }
    }

    /// Hashing on many threads must give the same results as hashing serially.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let wires = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        // Spawn every thread before joining any of them.
        let handles = wires
            .iter()
            .cloned()
            .map(|wire| std::thread::spawn(move || wire.hash(0u128)))
            .collect_vec();
        let hashes = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect_vec();

        let expected = wires.iter().map(|wire| wire.hash(0u128)).collect_vec();
        assert_eq!(hashes, expected);
    }

    /// A JSON round trip must reproduce the label.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let original = AllWire::rand(&mut rng, q);
            let json = serde_json::to_string(&original).unwrap();
            let roundtrip: AllWire = serde_json::from_str(&json).unwrap();
            assert_eq!(original, roundtrip);
        }
    }
}