1use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12 U8x16,
13 array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27 ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29 let batch = wires.array_map(|x| x.to_repr());
30 TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait with no methods (its only supertrait is `Clone`); `AllWire`
/// implements it in this module.
///
/// NOTE(review): presumably tags wire types usable by arithmetic (non-binary)
/// gadgets — confirm against the trait's users.
pub trait ArithmeticWire: Clone {}
36
37pub trait WireLabel:
42 Clone
43 + core::fmt::Debug
44 + HasModulus
45 + core::ops::Add<Output = Self>
46 + core::ops::AddAssign
47 + core::ops::Sub<Output = Self>
48 + core::ops::SubAssign
49 + core::ops::Neg<Output = Self>
50 + core::ops::Mul<u16, Output = Self>
51 + core::ops::MulAssign<u16>
52{
53 fn to_repr(&self) -> U8x16;
55
56 fn color(&self) -> u16;
58
59 fn from_repr(inp: U8x16, q: u16) -> Self;
66
67 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
73
74 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
80
81 fn hash_to_mod(hash: U8x16, q: u16) -> Self;
90
91 fn hashback(&self, tweak: u128, q: u16) -> Self {
102 let hash = self.hash(tweak);
103 Self::hash_to_mod(hash, q)
104 }
105
106 fn hash(&self, tweak: u128) -> U8x16 {
108 TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
109 }
110
111 fn constant<RNG: CryptoRng + RngCore>(
114 x: u16,
115 q: u16,
116 delta: &Self,
117 rng: &mut RNG,
118 ) -> (Self, Self) {
119 let zero = Self::rand(rng, q);
120 let wire = zero.clone() + delta.clone() * x;
121 (zero, wire)
122 }
123}
124
/// A wire label of any supported modulus.
///
/// Constructors in this module pick the variant from the modulus: `q = 2`
/// maps to [`WireMod2`], `q = 3` to [`WireMod3`], and every other modulus to
/// [`WireModQ`].
// Variant names/order are part of the serde representation; do not reorder.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AllWire {
    /// Label with modulus 2.
    Mod2(WireMod2),
    /// Label with modulus 3.
    Mod3(WireMod3),
    /// Label with any other modulus.
    ModN(WireModQ),
}
136
137impl HasModulus for AllWire {
138 fn modulus(&self) -> u16 {
139 match &self {
140 AllWire::Mod2(x) => x.modulus(),
141 AllWire::Mod3(x) => x.modulus(),
142 AllWire::ModN(x) => x.modulus(),
143 }
144 }
145}
146
147impl core::ops::Add for AllWire {
148 type Output = Self;
149
150 fn add(self, rhs: Self) -> Self::Output {
151 let (p, q) = (self.modulus(), rhs.modulus());
152 match (self, rhs) {
153 (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
154 (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
155 (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
156 _ => panic!("unequal moduli: {p} != {q}"),
157 }
158 }
159}
160
161impl core::ops::AddAssign for AllWire {
162 fn add_assign(&mut self, rhs: Self) {
163 let (p, q) = (self.modulus(), rhs.modulus());
164 match (self, rhs) {
165 (Self::Mod2(x), Self::Mod2(y)) => *x += y,
166 (Self::Mod3(x), Self::Mod3(y)) => *x += y,
167 (Self::ModN(x), Self::ModN(y)) => *x += y,
168 _ => panic!("unequal moduli: {p} != {q}"),
169 }
170 }
171}
172
173impl core::ops::Sub for AllWire {
174 type Output = Self;
175
176 fn sub(self, rhs: Self) -> Self::Output {
177 self + -rhs
178 }
179}
180
181impl core::ops::SubAssign for AllWire {
182 fn sub_assign(&mut self, rhs: Self) {
183 *self = self.clone() - rhs;
184 }
185}
186
187impl core::ops::Neg for AllWire {
188 type Output = Self;
189
190 fn neg(self) -> Self::Output {
191 match self {
192 Self::Mod2(x) => Self::Mod2(-x),
193 Self::Mod3(x) => Self::Mod3(-x),
194 Self::ModN(x) => Self::ModN(-x),
195 }
196 }
197}
198
199impl core::ops::Mul<u16> for AllWire {
200 type Output = Self;
201
202 fn mul(self, rhs: u16) -> Self::Output {
203 match self {
204 Self::Mod2(x) => Self::Mod2(x * rhs),
205 Self::Mod3(x) => Self::Mod3(x * rhs),
206 Self::ModN(x) => Self::ModN(x * rhs),
207 }
208 }
209}
210
211impl core::ops::MulAssign<u16> for AllWire {
212 fn mul_assign(&mut self, rhs: u16) {
213 match self {
214 Self::Mod2(x) => {
215 *x *= rhs;
216 }
217 Self::Mod3(x) => {
218 *x *= rhs;
219 }
220 Self::ModN(x) => {
221 *x *= rhs;
222 }
223 };
224 }
225}
226
227impl WireLabel for AllWire {
228 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
229 match q {
230 2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
231 3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
232 _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
233 }
234 }
235
236 fn to_repr(&self) -> U8x16 {
237 match &self {
238 AllWire::Mod2(x) => x.to_repr(),
239 AllWire::Mod3(x) => x.to_repr(),
240 AllWire::ModN(x) => x.to_repr(),
241 }
242 }
243 fn color(&self) -> u16 {
244 match &self {
245 AllWire::Mod2(x) => x.color(),
246 AllWire::Mod3(x) => x.color(),
247 AllWire::ModN(x) => x.color(),
248 }
249 }
250 fn from_repr(inp: U8x16, q: u16) -> Self {
251 match q {
252 2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
253 3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
254 _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
255 }
256 }
257
258 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
259 match q {
260 2 => AllWire::Mod2(WireMod2::rand(rng, q)),
261 3 => AllWire::Mod3(WireMod3::rand(rng, q)),
262 _ => AllWire::ModN(WireModQ::rand(rng, q)),
263 }
264 }
265
266 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
267 if q == 3 {
268 AllWire::Mod3(WireMod3::encode_block_mod3(hash))
269 } else {
270 Self::from_repr(hash, q)
271 }
272 }
273}
274fn _unrank(inp: u128, q: u16) -> Vec<u16> {
275 let mut x = inp;
276 let ndigits = util::digits_per_u128(q);
277 let npaths_tab = npaths_tab::lookup(q);
278 x %= npaths_tab[ndigits - 1] * q as u128;
279
280 let mut ds = vec![0; ndigits];
281 for i in (0..ndigits).rev() {
282 let npaths = npaths_tab[i];
283
284 if q <= 23 {
285 let mut acc = 0;
287 for j in 0..q {
288 acc += npaths;
289 if acc > x {
290 x -= acc - npaths;
291 ds[i] = j;
292 break;
293 }
294 }
295 } else {
296 let d = x / npaths;
298 ds[i] = d as u16;
299 x -= d * npaths;
300 }
301 }
323 ds
324}
325
// Marker impl: `ArithmeticWire` has no methods, so `AllWire` (which is `Clone`)
// opts in with an empty body.
impl ArithmeticWire for AllWire {}
327
// Unit tests for wire labels: repr round-trips, base-q digit conversion,
// hashing, negation and arithmetic identities, digit counts, cross-thread
// hash determinism, and (with the `serde` feature) JSON round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    // `from_repr(to_repr(w), q)` must reproduce `w` for every modulus in 2..256.
    #[test]
    fn packing() {
        let rng = &mut thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let w = AllWire::rand(rng, q);
                assert_eq!(w, AllWire::from_repr(w.to_repr(), q));
            }
        }
    }

    // `WireModQ::from_repr` must yield the same digits as the reference
    // conversion `util::as_base_q_u128`, for random moduli q in 5..=114.
    #[test]
    fn base_conversion_lookup_method() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let w = WireModQ::from_repr(U8x16::from(x), q);
            let should_be = util::as_base_q_u128(x, q);
            assert_eq!(w.ds, should_be, "x={} q={}", x, q);
        }
    }

    // `hashback` must change the label and produce nonzero components (the
    // value for `Mod2`, both halves for `Mod3`, at least one digit for `ModN`).
    // NOTE(review): these checks are probabilistic; a hash output could in
    // principle hit zero.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            assert!(x != y);
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    // Double negation is the identity. For q != 2 a random label should also
    // differ from its negation; q = 2 is skipped, presumably because negation
    // is trivial mod 2.
    #[test]
    fn negation() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(rng, q);
            let xneg = -x.clone();
            if q != 2 {
                assert!(x != xneg);
            }
            let y = -xneg;
            assert_eq!(x, y);
        }
    }

    // Ring identities on random labels: x*0 == x*q == x-x, repeated addition
    // vs scalar multiplication, double negation, subtraction as adding the
    // negation (x+y == x-y when q = 2), and agreement of `+=`, `*=` and
    // negation-by-assignment with their binary forms.
    #[test]
    // `x * 0` is deliberate: it checks that scaling by zero equals `x - x`.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() * 0, x.clone() - x.clone());
            assert_eq!(x.clone() * q, x.clone() - x.clone());
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), x.clone() - x.clone());
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }
            let mut w = x.clone();
            let z = w.clone() + y.clone();
            w += y;
            assert_eq!(w, z);

            w = x.clone();
            w *= 2;
            assert_eq!(x.clone() + x.clone(), w);

            w = x.clone();
            w = -w;
            assert_eq!(-x, w);
        }
    }

    // A random `WireModQ` must carry exactly `util::digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = WireModQ::rand(&mut rng, q);
            assert_eq!(x.ds.len(), util::digits_per_u128(q));
        }
    }

    // Hashing on spawned threads must give the same results as hashing the
    // same labels on the current thread.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let ws = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let mut handles = Vec::new();
        for w in ws.iter() {
            let w_ = w.clone();
            let h = std::thread::spawn(move || w_.hash(0u128));
            handles.push(h);
        }
        let hashes = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();

        let should_be = ws.iter().map(|w| w.hash(0u128)).collect_vec();

        assert_eq!(hashes, should_be);
    }

    // JSON serialize/deserialize must round-trip an `AllWire` for q in 2..16.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let w = AllWire::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: AllWire = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }
}