1use crate::{fancy::HasModulus, util};
9use rand::{CryptoRng, Rng, RngCore};
10use swanky_cr_hash::TweakableCircularCorrelationRobustHash;
11use vectoreyes::{
12 U8x16,
13 array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
14};
15
16mod mod2;
17pub use mod2::WireMod2;
18mod mod3;
19pub use mod3::WireMod3;
20mod modq;
21pub use modq::WireModQ;
22mod npaths_tab;
23
24pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: u128) -> [U8x16; Q]
26where
27 ArrayUnrolledOps: UnrollableArraySize<Q>,
28{
29 let batch = wires.array_map(|x| x.to_repr());
30 TweakableCircularCorrelationRobustHash::fixed_key().hash_many(batch, tweak)
31}
32
/// Marker trait with no methods of its own; implementors only need to be `Clone`.
pub trait ArithmeticWire
where
    Self: Clone,
{
}
36
37pub trait WireLabel:
42 Clone
43 + HasModulus
44 + core::ops::Add<Output = Self>
45 + core::ops::AddAssign
46 + core::ops::Sub<Output = Self>
47 + core::ops::SubAssign
48 + core::ops::Neg<Output = Self>
49 + core::ops::Mul<u16, Output = Self>
50 + core::ops::MulAssign<u16>
51{
52 fn to_repr(&self) -> U8x16;
54
55 fn color(&self) -> u16;
57
58 fn from_repr(inp: U8x16, q: u16) -> Self;
65
66 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
72
73 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
79
80 fn hash_to_mod(hash: U8x16, q: u16) -> Self;
89
90 fn hashback(&self, tweak: u128, q: u16) -> Self {
101 let hash = self.hash(tweak);
102 Self::hash_to_mod(hash, q)
103 }
104
105 fn hash(&self, tweak: u128) -> U8x16 {
107 TweakableCircularCorrelationRobustHash::fixed_key().hash(self.to_repr(), tweak)
108 }
109
110 fn constant<RNG: CryptoRng + RngCore>(
113 x: u16,
114 q: u16,
115 delta: &Self,
116 rng: &mut RNG,
117 ) -> (Self, Self) {
118 let zero = Self::rand(rng, q);
119 let wire = zero.clone() + delta.clone() * x;
120 (zero, wire)
121 }
122}
123
/// A wire label whose modulus is chosen at runtime.
///
/// This type's `WireLabel` constructors (`rand`, `rand_delta`, `from_repr`) pick the
/// variant from `q`: `Mod2` for `q == 2`, `Mod3` for `q == 3`, `ModN` otherwise.
///
/// NOTE: with the `serde` feature, variant order is part of the serialized form for
/// formats that encode enum variants by index — do not reorder the variants.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AllWire {
    /// Label produced for modulus 2.
    Mod2(WireMod2),
    /// Label produced for modulus 3.
    Mod3(WireMod3),
    /// Label produced for every other modulus.
    ModN(WireModQ),
}
135
136impl HasModulus for AllWire {
137 fn modulus(&self) -> u16 {
138 match &self {
139 AllWire::Mod2(x) => x.modulus(),
140 AllWire::Mod3(x) => x.modulus(),
141 AllWire::ModN(x) => x.modulus(),
142 }
143 }
144}
145
146impl core::ops::Add for AllWire {
147 type Output = Self;
148
149 fn add(self, rhs: Self) -> Self::Output {
150 let (p, q) = (self.modulus(), rhs.modulus());
151 match (self, rhs) {
152 (Self::Mod2(x), Self::Mod2(y)) => Self::Mod2(x + y),
153 (Self::Mod3(x), Self::Mod3(y)) => Self::Mod3(x + y),
154 (Self::ModN(x), Self::ModN(y)) => Self::ModN(x + y),
155 _ => panic!("unequal moduli: {p} != {q}"),
156 }
157 }
158}
159
160impl core::ops::AddAssign for AllWire {
161 fn add_assign(&mut self, rhs: Self) {
162 let (p, q) = (self.modulus(), rhs.modulus());
163 match (self, rhs) {
164 (Self::Mod2(x), Self::Mod2(y)) => *x += y,
165 (Self::Mod3(x), Self::Mod3(y)) => *x += y,
166 (Self::ModN(x), Self::ModN(y)) => *x += y,
167 _ => panic!("unequal moduli: {p} != {q}"),
168 }
169 }
170}
171
172impl core::ops::Sub for AllWire {
173 type Output = Self;
174
175 fn sub(self, rhs: Self) -> Self::Output {
176 self + -rhs
177 }
178}
179
180impl core::ops::SubAssign for AllWire {
181 fn sub_assign(&mut self, rhs: Self) {
182 *self = self.clone() - rhs;
183 }
184}
185
186impl core::ops::Neg for AllWire {
187 type Output = Self;
188
189 fn neg(self) -> Self::Output {
190 match self {
191 Self::Mod2(x) => Self::Mod2(-x),
192 Self::Mod3(x) => Self::Mod3(-x),
193 Self::ModN(x) => Self::ModN(-x),
194 }
195 }
196}
197
198impl core::ops::Mul<u16> for AllWire {
199 type Output = Self;
200
201 fn mul(self, rhs: u16) -> Self::Output {
202 match self {
203 Self::Mod2(x) => Self::Mod2(x * rhs),
204 Self::Mod3(x) => Self::Mod3(x * rhs),
205 Self::ModN(x) => Self::ModN(x * rhs),
206 }
207 }
208}
209
210impl core::ops::MulAssign<u16> for AllWire {
211 fn mul_assign(&mut self, rhs: u16) {
212 match self {
213 Self::Mod2(x) => {
214 *x *= rhs;
215 }
216 Self::Mod3(x) => {
217 *x *= rhs;
218 }
219 Self::ModN(x) => {
220 *x *= rhs;
221 }
222 };
223 }
224}
225
226impl WireLabel for AllWire {
227 fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
228 match q {
229 2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
230 3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
231 _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
232 }
233 }
234
235 fn to_repr(&self) -> U8x16 {
236 match &self {
237 AllWire::Mod2(x) => x.to_repr(),
238 AllWire::Mod3(x) => x.to_repr(),
239 AllWire::ModN(x) => x.to_repr(),
240 }
241 }
242 fn color(&self) -> u16 {
243 match &self {
244 AllWire::Mod2(x) => x.color(),
245 AllWire::Mod3(x) => x.color(),
246 AllWire::ModN(x) => x.color(),
247 }
248 }
249 fn from_repr(inp: U8x16, q: u16) -> Self {
250 match q {
251 2 => AllWire::Mod2(WireMod2::from_repr(inp, q)),
252 3 => AllWire::Mod3(WireMod3::from_repr(inp, q)),
253 _ => AllWire::ModN(WireModQ::from_repr(inp, q)),
254 }
255 }
256
257 fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
258 match q {
259 2 => AllWire::Mod2(WireMod2::rand(rng, q)),
260 3 => AllWire::Mod3(WireMod3::rand(rng, q)),
261 _ => AllWire::ModN(WireModQ::rand(rng, q)),
262 }
263 }
264
265 fn hash_to_mod(hash: U8x16, q: u16) -> Self {
266 if q == 3 {
267 AllWire::Mod3(WireMod3::encode_block_mod3(hash))
268 } else {
269 Self::from_repr(hash, q)
270 }
271 }
272}
273fn _unrank(inp: u128, q: u16) -> Vec<u16> {
274 let mut x = inp;
275 let ndigits = util::digits_per_u128(q);
276 let npaths_tab = npaths_tab::lookup(q);
277 x %= npaths_tab[ndigits - 1] * q as u128;
278
279 let mut ds = vec![0; ndigits];
280 for i in (0..ndigits).rev() {
281 let npaths = npaths_tab[i];
282
283 if q <= 23 {
284 let mut acc = 0;
286 for j in 0..q {
287 acc += npaths;
288 if acc > x {
289 x -= acc - npaths;
290 ds[i] = j;
291 break;
292 }
293 }
294 } else {
295 let d = x / npaths;
297 ds[i] = d as u16;
298 x -= d * npaths;
299 }
300 }
322 ds
323}
324
325impl ArithmeticWire for AllWire {}
326
// Unit tests for the wire-label types and `AllWire` dispatch in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    // Round trip: `from_repr(to_repr(w), q)` must reproduce `w` for every q in 2..256.
    #[test]
    fn packing() {
        let rng = &mut thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let w = AllWire::rand(rng, q);
                assert_eq!(w, AllWire::from_repr(w.to_repr(), q));
            }
        }
    }

    // `WireModQ::from_repr` must yield the same digits as the reference
    // `util::as_base_q_u128` conversion, for q in 5..=114.
    #[test]
    fn base_conversion_lookup_method() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let w = WireModQ::from_repr(U8x16::from(x), q);
            let should_be = util::as_base_q_u128(x, q);
            assert_eq!(w.ds, should_be, "x={} q={}", x, q);
        }
    }

    // `hashback` must change the label and not produce an all-zero label
    // (q in 2..=111). These checks are probabilistic: an extremely unlucky
    // random draw could in principle trip them.
    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(1u128, q);
            assert!(x != y);
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    // Negation is an involution, and (except mod 2, where -x == x) changes the label.
    #[test]
    fn negation() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(rng, q);
            let xneg = -x.clone();
            if q != 2 {
                assert!(x != xneg);
            }
            let y = -xneg;
            assert_eq!(x, y);
        }
    }

    // Consistency between +, -, unary -, scalar *, and the in-place operators.
    #[test]
    // `x.clone() * 0` below intentionally multiplies by zero.
    #[allow(clippy::erasing_op)]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            assert_eq!(x.clone() * 0, x.clone() - x.clone());
            assert_eq!(x.clone() * q, x.clone() - x.clone());
            assert_eq!(x.clone() + x.clone(), x.clone() * 2);
            assert_eq!(x.clone() + x.clone() + x.clone(), x.clone() * 3);
            assert_eq!(-(-x.clone()), x);
            if q == 2 {
                assert_eq!(x.clone() + y.clone(), x.clone() - y.clone());
            } else {
                assert_eq!(x.clone() + -x.clone(), x.clone() - x.clone());
                assert_eq!(x.clone() + -y.clone(), x.clone() - y.clone());
            }
            // `+=` must agree with `+`.
            let mut w = x.clone();
            let z = w.clone() + y.clone();
            w += y;
            assert_eq!(w, z);

            // `*=` must agree with repeated `+`.
            w = x.clone();
            w *= 2;
            assert_eq!(x.clone() + x.clone(), w);

            w = x.clone();
            w = -w;
            assert_eq!(-x, w);
        }
    }

    // A random `WireModQ` must carry exactly `digits_per_u128(q)` digits.
    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = WireModQ::rand(&mut rng, q);
            assert_eq!(x.ds.len(), util::digits_per_u128(q));
        }
    }

    // Hashing on spawned threads must give the same results as hashing on this thread.
    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let ws = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let mut handles = Vec::new();
        for w in ws.iter() {
            let w_ = w.clone();
            let h = std::thread::spawn(move || w_.hash(0u128));
            handles.push(h);
        }
        let hashes = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();

        let should_be = ws.iter().map(|w| w.hash(0u128)).collect_vec();

        assert_eq!(hashes, should_be);
    }

    // JSON round trip of `AllWire` for q in 2..16 (only with the `serde` feature).
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let w = AllWire::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: AllWire = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }
}