fancy_garbling/
wire.rs

1//! Low-level operations on wire-labels, the basic building block of garbled circuits.
2
3use crate::{fancy::HasModulus, util};
4use rand::{CryptoRng, Rng, RngCore};
5use subtle::ConditionallySelectable;
6use swanky_aes_hash::AesHash;
7use swanky_block::Block;
8use vectoreyes::{
9    SimdBase,
10    array_utils::{ArrayUnrolledExt, ArrayUnrolledOps, UnrollableArraySize},
11};
12
13#[cfg(feature = "serde")]
14use crate::errors::{ModQDeserializationError, WireDeserializationError};
15
16mod npaths_tab;
17
18#[derive(Debug, Clone, PartialEq)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20/// The core wire-label type.
21pub enum AllWire {
22    /// Modulo2 Wire
23    Mod2(WireMod2),
24
25    /// Modulo3 Wire
26    Mod3(WireMod3),
27
28    /// Modulo q Wire: 3 < q < 2^16
29    ModN(WireModQ),
30}
31
32/// Batch hashing of wires
33pub fn hash_wires<const Q: usize, W: WireLabel>(wires: [&W; Q], tweak: Block) -> [Block; Q]
34where
35    ArrayUnrolledOps: UnrollableArraySize<Q>,
36{
37    let batch = wires.array_map(|x| x.as_block());
38    AesHash::fixed_key().tccr_hash_many(tweak, batch)
39}
40
/// Marker trait indicating an arithmetic wire.
///
/// It declares no methods; implementors only need to be `Clone`.
pub trait ArithmeticWire
where
    Self: Clone,
{
}
43
44/// Trait implementing a wire that can be used for secure computation
45/// via garbled circuits
46pub trait WireLabel: Clone + HasModulus {
47    /// Get the digits of the wire
48    fn digits(&self) -> Vec<u16>;
49
50    /// Pack the wire into a `Block`.
51    fn as_block(&self) -> Block;
52
53    /// Get the color digit of the wire.
54    fn color(&self) -> u16;
55
56    /// Add another wire digit-wise into this one. Assumes that both wires have
57    /// the same modulus.
58    fn plus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self;
59
60    /// Multiply each digit by a constant `c mod q`.
61    fn cmul_eq(&mut self, c: u16) -> &mut Self;
62
63    /// Negate all the digits mod q.
64    fn negate_eq(&mut self) -> &mut Self;
65
66    /// Pack the wire into a `Block`.
67    fn from_block(inp: Block, q: u16) -> Self;
68
69    /// The zero wire with modulus `q`
70    fn zero(q: u16) -> Self;
71
72    /// Get a random wire label mod `q`, with the first digit set to `1`
73    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self;
74
75    /// Get a random wire `mod q`.
76    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self;
77
78    /// Subroutine of hashback that converts the hash block into a valid wire of the given
79    /// modulus. Also useful when batching hashes ahead of time for later conversion.
80    fn hash_to_mod(hash: Block, q: u16) -> Self;
81
82    /// Compute the hash of this wire, converting the result back to a wire.
83    ///
84    /// Uses fixed-key AES.
85    fn hashback(&self, tweak: Block, q: u16) -> Self {
86        let hash = self.hash(tweak);
87        Self::hash_to_mod(hash, q)
88    }
89
90    /// Negate all the digits `mod q`, consuming it for chained computations.
91    fn negate_mov(mut self) -> Self {
92        self.negate_eq();
93        self
94    }
95
96    /// Multiply each digit by a constant `c mod q`, consuming it for chained computations.
97    fn cmul_mov(mut self, c: u16) -> Self {
98        self.cmul_eq(c);
99        self
100    }
101
102    /// Multiply each digit by a constant `c mod q`, returning a new wire.
103    fn cmul(&self, c: u16) -> Self {
104        self.clone().cmul_mov(c)
105    }
106
107    /// Add another wire into this one, consuming it for chained computations.
108    fn plus_mov(mut self, other: &Self) -> Self {
109        self.plus_eq(other);
110        self
111    }
112
113    /// Add two wires digit-wise, returning a new wire.
114    fn plus(&self, other: &Self) -> Self {
115        self.clone().plus_mov(other)
116    }
117
118    /// Negate all the digits `mod q`, returning a new wire.
119    fn negate(&self) -> Self {
120        self.clone().negate_mov()
121    }
122
123    /// Subtract a wire from this one, consuming it for chained computations.
124    fn minus_mov(mut self, other: &Self) -> Self {
125        self.minus_eq(other);
126        self
127    }
128
129    /// Subtract two wires, returning the result.
130    fn minus(&self, other: &Self) -> Self {
131        self.clone().minus_mov(other)
132    }
133
134    /// Subtract a wire from this one.
135    fn minus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self {
136        self.plus_eq(&other.negate());
137        self
138    }
139
140    /// Compute the hash of this wire.
141    ///
142    /// Uses fixed-key AES.
143    #[inline(never)]
144    fn hash(&self, tweak: Block) -> Block {
145        AesHash::fixed_key().tccr_hash(tweak, self.as_block())
146    }
147}
148
149/// Representation of a `mod-2` wire.
150#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
151#[derive(Debug, Clone, Copy, PartialEq, Default)]
152pub struct WireMod2 {
153    /// A 128-bit value.
154    val: Block,
155}
156
157impl ConditionallySelectable for WireMod2 {
158    fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
159        WireMod2::from_block(
160            Block::conditional_select(&a.as_block(), &b.as_block(), choice),
161            2,
162        )
163    }
164}
165
/// Unvalidated mirror of `WireMod3`, used as the serde deserialization target.
///
/// It is turned into a `WireMod3` only through `TryFrom`, which rejects any
/// input where a digit position has both its `lsb` and `msb` bit set.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireMod3 {
    /// Low bit of each packed `mod-3` digit.
    lsb: u64,
    /// High bit of each packed `mod-3` digit.
    msb: u64,
}
177
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireMod3> for WireMod3 {
    type Error = WireDeserializationError;

    /// Accept the raw bit-planes only if no digit position has both bits set,
    /// since the pattern `0b11` is not a valid `mod-3` digit.
    fn try_from(wire: UntrustedWireMod3) -> Result<Self, Self::Error> {
        let UntrustedWireMod3 { lsb, msb } = wire;
        let overlap = lsb & msb;
        if overlap == 0 {
            Ok(WireMod3 { lsb, msb })
        } else {
            Err(Self::Error::InvalidWireMod3)
        }
    }
}
192
/// Unvalidated mirror of `WireModQ`, used as the serde deserialization target.
///
/// It is turned into a `WireModQ` only through `TryFrom`, which checks the
/// modulus, the number of digits, and that every digit is below the modulus.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireModQ {
    /// Claimed modulus of the wire label (moduli are assumed to fit in `u16`).
    q: u16,
    /// Claimed `mod-q` digits.
    ds: Vec<u16>,
}
204
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireModQ> for WireModQ {
    type Error = WireDeserializationError;

    /// Validate an untrusted `mod-q` wire.
    ///
    /// Checks, in order:
    /// - the modulus is at least 2;
    /// - there are exactly `util::digits_per_u128(q)` digits;
    /// - every digit is below `q` (the first offending digit is reported).
    fn try_from(wire: UntrustedWireModQ) -> Result<Self, Self::Error> {
        let UntrustedWireModQ { q, ds } = wire;

        if q < 2 {
            let reason = ModQDeserializationError::BadModulus(q);
            return Err(Self::Error::InvalidWireModQ(reason));
        }

        let needed = util::digits_per_u128(q);
        let got = ds.len();
        if got != needed {
            let reason = ModQDeserializationError::InvalidDigitsLength { got, needed };
            return Err(Self::Error::InvalidWireModQ(reason));
        }

        if let Some(digit) = ds.iter().copied().find(|&d| d >= q) {
            let reason = ModQDeserializationError::DigitTooLarge { digit, modulus: q };
            return Err(Self::Error::InvalidWireModQ(reason));
        }

        Ok(WireModQ { q, ds })
    }
}
242
/// Representation of a `mod-3` wire.
///
/// The wire holds 64 `mod-3` digits in bit-sliced form. Bit `i` of `lsb` is
/// the low bit of digit `i`, and bit `i` of `msb` is its high bit. The digits
/// 0, 1 and 2 are therefore encoded as `(lsb, msb)` bit pairs (0, 0), (1, 0)
/// and (0, 1).
///
/// This layout allows efficient addition and multiplication, as described in
/// "Hardware Implementation of Finite Fields of Characteristic Three",
/// D. Page and N.P. Smart, CHES 2002:
/// <https://link.springer.com/content/pdf/10.1007/3-540-36400-5_38.pdf>.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "UntrustedWireMod3")
)]
pub struct WireMod3 {
    /// Low bit of each `mod-3` digit.
    lsb: u64,
    /// High bit of each `mod-3` digit.
    msb: u64,
}
262
/// Representation of a `mod-q` wire.
///
/// A `mod-q` wire for `q > 3` is stored as its modulus `q` together with a
/// list of `mod-q` digits. The modulus is assumed to fit in a `u16`.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "UntrustedWireModQ")
)]
pub struct WireModQ {
    /// The modulus of the wire label.
    q: u16,
    /// The `mod-q` digits of the wire label.
    ds: Vec<u16>,
}
277
278impl HasModulus for WireMod2 {
279    fn modulus(&self) -> u16 {
280        2
281    }
282}
283
284impl HasModulus for WireMod3 {
285    fn modulus(&self) -> u16 {
286        3
287    }
288}
289
290impl HasModulus for WireModQ {
291    fn modulus(&self) -> u16 {
292        self.q
293    }
294}
295
296impl HasModulus for AllWire {
297    fn modulus(&self) -> u16 {
298        match &self {
299            AllWire::Mod2(x) => x.modulus(),
300            AllWire::Mod3(x) => x.modulus(),
301            AllWire::ModN(x) => x.modulus(),
302        }
303    }
304}
305
306impl WireLabel for AllWire {
307    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
308        match q {
309            2 => AllWire::Mod2(WireMod2::rand_delta(rng, q)),
310            3 => AllWire::Mod3(WireMod3::rand_delta(rng, q)),
311            _ => AllWire::ModN(WireModQ::rand_delta(rng, q)),
312        }
313    }
314
315    fn digits(&self) -> Vec<u16> {
316        match &self {
317            AllWire::Mod2(x) => x.digits(),
318            AllWire::Mod3(x) => x.digits(),
319            AllWire::ModN(x) => x.digits(),
320        }
321    }
322
323    fn as_block(&self) -> Block {
324        match &self {
325            AllWire::Mod2(x) => x.as_block(),
326            AllWire::Mod3(x) => x.as_block(),
327            AllWire::ModN(x) => x.as_block(),
328        }
329    }
330    fn color(&self) -> u16 {
331        match &self {
332            AllWire::Mod2(x) => x.color(),
333            AllWire::Mod3(x) => x.color(),
334            AllWire::ModN(x) => x.color(),
335        }
336    }
337    fn plus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self {
338        match (&mut *self, other) {
339            (AllWire::Mod2(x), AllWire::Mod2(y)) => {
340                x.plus_eq(y);
341            }
342            (AllWire::Mod3(x), AllWire::Mod3(y)) => {
343                x.plus_eq(y);
344            }
345            (AllWire::ModN(x), AllWire::ModN(y)) => {
346                x.plus_eq(y);
347            }
348            _ => {
349                panic!(
350                    "[AllWire::plus_eq] unequal moduli: {}, {}!",
351                    self.modulus(),
352                    other.modulus()
353                )
354            }
355        };
356        self
357    }
358
359    fn cmul_eq(&mut self, c: u16) -> &mut Self {
360        match &mut *self {
361            AllWire::Mod2(x) => {
362                x.cmul_eq(c);
363            }
364            AllWire::Mod3(x) => {
365                x.cmul_eq(c);
366            }
367            AllWire::ModN(x) => {
368                x.cmul_eq(c);
369            }
370        };
371        self
372    }
373    fn negate_eq(&mut self) -> &mut Self {
374        match &mut *self {
375            AllWire::Mod2(x) => {
376                x.negate_eq();
377            }
378            AllWire::Mod3(x) => {
379                x.negate_eq();
380            }
381            AllWire::ModN(x) => {
382                x.negate_eq();
383            }
384        };
385        self
386    }
387    fn from_block(inp: Block, q: u16) -> Self {
388        match q {
389            2 => AllWire::Mod2(WireMod2::from_block(inp, q)),
390            3 => AllWire::Mod3(WireMod3::from_block(inp, q)),
391            _ => AllWire::ModN(WireModQ::from_block(inp, q)),
392        }
393    }
394
395    fn zero(q: u16) -> Self {
396        match q {
397            2 => AllWire::Mod2(WireMod2::zero(q)),
398            3 => AllWire::Mod3(WireMod3::zero(q)),
399            _ => AllWire::ModN(WireModQ::zero(q)),
400        }
401    }
402
403    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
404        match q {
405            2 => AllWire::Mod2(WireMod2::rand(rng, q)),
406            3 => AllWire::Mod3(WireMod3::rand(rng, q)),
407            _ => AllWire::ModN(WireModQ::rand(rng, q)),
408        }
409    }
410
411    fn hash_to_mod(hash: Block, q: u16) -> Self {
412        if q == 3 {
413            AllWire::Mod3(WireMod3::encode_block_mod3(hash))
414        } else {
415            Self::from_block(hash, q)
416        }
417    }
418}
419
420impl WireMod3 {
421    /// We have to convert `block` into a valid `Mod3` encoding.
422    ///
423    /// We do this by computing the `Mod3` digits using `_unrank`,
424    /// and then map these to a `Mod3` encoding.
425    fn encode_block_mod3(block: Block) -> Self {
426        let mut lsb = 0u64;
427        let mut msb = 0u64;
428        let mut ds = _unrank(u128::from(block), 3);
429        for (i, v) in ds.drain(..64).enumerate() {
430            lsb |= ((v & 1) as u64) << i;
431            msb |= (((v >> 1) & 1u16) as u64) << i;
432        }
433        debug_assert_eq!(lsb & msb, 0);
434        Self { lsb, msb }
435    }
436}
437
438impl WireLabel for WireMod2 {
439    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
440        if q != 2 {
441            panic!("[WireMod2::rand_delta] Expected modulo 2. Got {}", q);
442        }
443        let mut w = Self::rand(rng, q);
444        w.val |= Block::set_lo(1);
445        w
446    }
447
448    fn digits(&self) -> Vec<u16> {
449        (0..128)
450            .map(|i| ((u128::from(self.val) >> i) as u16) & 1)
451            .collect()
452    }
453
454    fn as_block(&self) -> Block {
455        self.val
456    }
457
458    fn color(&self) -> u16 {
459        // This extracts the least-significant bit of the U8x16.
460        (self.val.extract::<0>() & 1) as u16
461    }
462
463    fn plus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self {
464        self.val ^= other.val;
465        self
466    }
467
468    fn cmul_eq(&mut self, c: u16) -> &mut Self {
469        if c & 1 == 0 {
470            self.val = Block::default();
471        }
472        self
473    }
474
475    fn negate_eq(&mut self) -> &mut Self {
476        // Do nothing. Additive inverse is a no-op for mod 2.
477        self
478    }
479
480    fn from_block(inp: Block, q: u16) -> Self {
481        if q != 2 {
482            panic!("[WireMod2::from_block] Expected modulo 2. Got {}", q);
483        }
484        Self { val: inp }
485    }
486
487    fn zero(q: u16) -> Self {
488        if q != 2 {
489            panic!("[WireMod2::zero] Expected modulo 2. Got {}", q);
490        }
491        Self::default()
492    }
493
494    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
495        if q != 2 {
496            panic!("[WireMod2::rand] Expected modulo 2. Got {}", q);
497        }
498
499        Self { val: rng.r#gen() }
500    }
501
502    fn hash_to_mod(hash: Block, q: u16) -> Self {
503        if q != 2 {
504            panic!("[WireMod2::hash_to_mod] Expected modulo 2. Got {}", q);
505        }
506        Self::from_block(hash, q)
507    }
508}
509
510impl WireLabel for WireMod3 {
511    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
512        if q != 3 {
513            panic!("[WireMod3::rand_delta] Expected modulo 3. Got {}", q);
514        }
515        let mut w = Self::rand(rng, 3);
516        w.lsb |= 1;
517        w.msb &= 0xFFFF_FFFF_FFFF_FFFE;
518        w
519    }
520
521    fn digits(&self) -> Vec<u16> {
522        (0..64)
523            .map(|i| (((self.lsb >> i) as u16) & 1) & ((((self.msb >> i) as u16) & 1) << 1))
524            .collect()
525    }
526
527    fn as_block(&self) -> Block {
528        Block::from(((self.msb as u128) << 64) | (self.lsb as u128))
529    }
530
531    fn color(&self) -> u16 {
532        let color = (((self.msb & 1) as u16) << 1) | ((self.lsb & 1) as u16);
533        debug_assert_ne!(color, 3);
534        color
535    }
536
537    fn plus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self {
538        let a1 = &mut self.lsb;
539        let a2 = &mut self.msb;
540        let b1 = other.lsb;
541        let b2 = other.msb;
542
543        let t = (*a1 | b2) ^ (*a2 | b1);
544        let c1 = (*a2 | b2) ^ t;
545        let c2 = (*a1 | b1) ^ t;
546        *a1 = c1;
547        *a2 = c2;
548        self
549    }
550
551    fn cmul_eq(&mut self, c: u16) -> &mut Self {
552        match c {
553            0 => {
554                self.msb = 0;
555                self.lsb = 0;
556            }
557            1 => {}
558            2 => {
559                std::mem::swap(&mut self.lsb, &mut self.msb);
560            }
561            c => {
562                self.cmul_eq(c % 3);
563            }
564        }
565        self
566    }
567
568    fn negate_eq(&mut self) -> &mut Self {
569        // Negation just involves swapping `lsb` and `msb`.
570        std::mem::swap(&mut self.lsb, &mut self.msb);
571        self
572    }
573
574    fn from_block(inp: Block, q: u16) -> Self {
575        if q != 3 {
576            panic!("[WireMod3::from_block] Expected mod 3. Got mod {}", q)
577        }
578        let inp = u128::from(inp);
579        let lsb = inp as u64;
580        let msb = (inp >> 64) as u64;
581        debug_assert_eq!(lsb & msb, 0);
582        Self { lsb, msb }
583    }
584
585    fn zero(q: u16) -> Self {
586        if q != 3 {
587            panic!("[WireMod3::zero] Expected modulo 3. Got {}", q);
588        }
589        Self::default()
590    }
591
592    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
593        if q != 3 {
594            panic!("[WireMod3::rand] Expected mod 3. Got mod {}", q)
595        }
596        let mut lsb = 0u64;
597        let mut msb = 0u64;
598        for (i, v) in (0..64).map(|_| rng.r#gen::<u8>() % 3).enumerate() {
599            lsb |= ((v & 1) as u64) << i;
600            msb |= (((v >> 1) & 1) as u64) << i;
601        }
602        debug_assert_eq!(lsb & msb, 0);
603        Self { lsb, msb }
604    }
605
606    fn hash_to_mod(hash: Block, q: u16) -> Self {
607        if q != 3 {
608            panic!("[WireMod3::hash_to_mod] Expected mod 3. Got mod {}", q)
609        }
610        Self::encode_block_mod3(hash)
611    }
612}
613
614impl WireLabel for WireModQ {
615    fn rand_delta<R: CryptoRng + Rng>(rng: &mut R, q: u16) -> Self {
616        if q < 2 {
617            panic!(
618                "[WireModQ::rand_delta] Modulus must be at least 2. Got {}",
619                q
620            );
621        }
622        let mut w = Self::rand(rng, q);
623        w.ds[0] = 1;
624        w
625    }
626
627    fn digits(&self) -> Vec<u16> {
628        self.ds.clone()
629    }
630
631    fn as_block(&self) -> Block {
632        Block::from(util::from_base_q(&self.ds, self.q))
633    }
634
635    fn color(&self) -> u16 {
636        let color = self.ds[0];
637        debug_assert!(color < self.q);
638        color
639    }
640
641    fn plus_eq<'a>(&'a mut self, other: &Self) -> &'a mut Self {
642        let xs = &mut self.ds;
643        let ys = &other.ds;
644        let q = self.q;
645
646        // Assuming modulus has to be the same here
647        // Will enforce by type system
648        //debug_assert_eq!(, ymod);
649        debug_assert_eq!(xs.len(), ys.len());
650        xs.iter_mut().zip(ys.iter()).for_each(|(x, &y)| {
651            let (zp, overflow) = (*x + y).overflowing_sub(q);
652            *x = if overflow { *x + y } else { zp }
653        });
654
655        self
656    }
657
658    fn cmul_eq(&mut self, c: u16) -> &mut Self {
659        let q = self.q;
660        self.ds
661            .iter_mut()
662            .for_each(|d| *d = (*d as u32 * c as u32 % q as u32) as u16);
663        self
664    }
665
666    fn negate_eq(&mut self) -> &mut Self {
667        let q = self.q;
668        self.ds.iter_mut().for_each(|d| {
669            if *d > 0 {
670                *d = q - *d;
671            } else {
672                *d = 0;
673            }
674        });
675        self
676    }
677    fn from_block(inp: Block, q: u16) -> Self {
678        if q < 2 {
679            panic!(
680                "[WireModQ::from_block] Modulus must be at least 2. Got {}",
681                q
682            );
683        }
684        let ds = if util::is_power_of_2(q) {
685            // It's a power of 2, just split the digits.
686            let ndigits = util::digits_per_u128(q);
687            let width = 128 / ndigits;
688            let mask = (1 << width) - 1;
689            let x = u128::from(inp);
690            (0..ndigits)
691                .map(|i| ((x >> (width * i)) & mask) as u16)
692                .collect::<Vec<u16>>()
693        } else if q <= 23 {
694            _unrank(u128::from(inp), q)
695        } else {
696            // If all else fails, do unrank using naive division.
697            _unrank(u128::from(inp), q)
698        };
699        Self { q, ds }
700    }
701    /// Unpack the wire represented by a `Block` with modulus `q`. Assumes that
702    /// the block was constructed through the `AllWire` API.
703    fn zero(q: u16) -> Self {
704        if q < 2 {
705            panic!("[WireModQ::zero] Modulus must be at least 2. Got {}", q);
706        }
707        Self {
708            q,
709            ds: vec![0; util::digits_per_u128(q)],
710        }
711    }
712    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
713        if q < 2 {
714            panic!("[WireModQ::rand] Modulus must be at least 2. Got {}", q);
715        }
716        let ds = (0..util::digits_per_u128(q))
717            .map(|_| rng.r#gen::<u16>() % q)
718            .collect();
719        Self { q, ds }
720    }
721
722    fn hash_to_mod(hash: Block, q: u16) -> Self {
723        if q < 2 {
724            panic!(
725                "[WireModQ::hash_to_mod] Modulus must be at least 2. Got {}",
726                q
727            );
728        }
729        Self::from_block(hash, q)
730    }
731}
732
733fn _unrank(inp: u128, q: u16) -> Vec<u16> {
734    let mut x = inp;
735    let ndigits = util::digits_per_u128(q);
736    let npaths_tab = npaths_tab::lookup(q);
737    x %= npaths_tab[ndigits - 1] * q as u128;
738
739    let mut ds = vec![0; ndigits];
740    for i in (0..ndigits).rev() {
741        let npaths = npaths_tab[i];
742
743        if q <= 23 {
744            // linear search
745            let mut acc = 0;
746            for j in 0..q {
747                acc += npaths;
748                if acc > x {
749                    x -= acc - npaths;
750                    ds[i] = j;
751                    break;
752                }
753            }
754        } else {
755            // naive division
756            let d = x / npaths;
757            ds[i] = d as u16;
758            x -= d * npaths;
759        }
760        // } else {
761        //     // binary search
762        //     let mut low = 0;
763        //     let mut high = q;
764        //     loop {
765        //         let cur = (low + high) / 2;
766        //         let l = npaths * cur as u128;
767        //         let r = npaths * (cur as u128 + 1);
768        //         if x >= l && x < r {
769        //             x -= l;
770        //             ds[i] = cur;
771        //             break;
772        //         }
773        //         if x < l {
774        //             high = cur;
775        //         } else {
776        //             // x >= r
777        //             low = cur;
778        //         }
779        //     }
780        // }
781    }
782    ds
783}
784
785impl ArithmeticWire for WireMod3 {}
786impl ArithmeticWire for WireModQ {}
787impl ArithmeticWire for AllWire {}
788
////////////////////////////////////////////////////////////////////////////////
// tests

#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::RngExt;
    use itertools::Itertools;
    use rand::thread_rng;

    #[test]
    fn packing() {
        let rng = &mut thread_rng();
        for q in 2..256 {
            for _ in 0..1000 {
                let w = AllWire::rand(rng, q);
                assert_eq!(w, AllWire::from_block(w.as_block(), q));
            }
        }
    }

    #[test]
    fn base_conversion_lookup_method() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = 5 + (rng.gen_u16() % 110);
            let x = rng.gen_u128();
            let w = AllWire::from_block(Block::from(x), q);
            let should_be = util::as_base_q_u128(x, q);
            assert_eq!(w.digits(), should_be, "x={} q={}", x, q);
        }
    }

    #[test]
    fn hash() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let q = 2 + (rng.gen_u16() % 110);
            let x = AllWire::rand(&mut rng, q);
            let y = x.hashback(Block::from(1u128), q);
            assert_ne!(x, y);
            match y {
                AllWire::Mod2(WireMod2 { val }) => assert!(u128::from(val) > 0),
                AllWire::Mod3(WireMod3 { lsb, msb }) => assert!(lsb > 0 && msb > 0),
                AllWire::ModN(WireModQ { ds, .. }) => assert!(!ds.iter().all(|&y| y == 0)),
            }
        }
    }

    #[test]
    fn negation() {
        let rng = &mut thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(rng, q);
            let xneg = x.negate();
            if q != 2 {
                assert_ne!(x, xneg);
            }
            let y = xneg.negate();
            assert_eq!(x, y);
        }
    }

    #[test]
    fn zero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            // Start at 2 so that the `Mod2` representation is covered too.
            let q = 2 + (rng.gen_u16() % 111);
            let z = AllWire::zero(q);
            let ds = z.digits();
            assert_eq!(ds, vec![0; ds.len()], "q={}", q);
        }
    }

    #[test]
    fn subzero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let z = AllWire::zero(q);
            assert_eq!(x.minus(&x), z);
        }
    }

    #[test]
    fn pluszero() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.plus(&AllWire::zero(q)), x);
        }
    }

    #[test]
    fn arithmetic() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            let y = AllWire::rand(&mut rng, q);
            assert_eq!(x.cmul(0), AllWire::zero(q));
            assert_eq!(x.cmul(q), AllWire::zero(q));
            assert_eq!(x.plus(&x), x.cmul(2));
            assert_eq!(x.plus(&x).plus(&x), x.cmul(3));
            assert_eq!(x.negate().negate(), x);
            if q == 2 {
                assert_eq!(x.plus(&y), x.minus(&y));
            } else {
                assert_eq!(x.plus(&x.negate()), AllWire::zero(q), "q={}", q);
                assert_eq!(x.minus(&y), x.plus(&y.negate()));
            }
            let mut w = x.clone();
            let z = w.plus(&y);
            w.plus_eq(&y);
            assert_eq!(w, z);

            w = x.clone();
            w.cmul_eq(2);
            assert_eq!(x.plus(&x), w);

            w = x.clone();
            w.negate_eq();
            assert_eq!(x.negate(), w);
        }
    }

    #[test]
    fn ndigits_correct() {
        let mut rng = thread_rng();
        for _ in 0..1024 {
            let q = rng.gen_modulus();
            let x = AllWire::rand(&mut rng, q);
            assert_eq!(x.digits().len(), util::digits_per_u128(q));
        }
    }

    #[test]
    fn parallel_hash() {
        let n = 1000;
        let mut rng = thread_rng();
        let q = rng.gen_modulus();
        let ws = (0..n).map(|_| AllWire::rand(&mut rng, q)).collect_vec();

        let mut handles = Vec::new();
        for w in ws.iter() {
            let w_ = w.clone();
            let h = std::thread::spawn(move || w_.hash(Block::default()));
            handles.push(h);
        }
        let hashes = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();

        let should_be = ws.iter().map(|w| w.hash(Block::default())).collect_vec();

        assert_eq!(hashes, should_be);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_mod2() {
        let mut rng = thread_rng();
        let w = WireMod2::rand(&mut rng, 2);
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: WireMod2 = serde_json::from_str(&serialized).unwrap();

        assert_eq!(w, deserialized);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_allwire() {
        let mut rng = thread_rng();
        for q in 2..16 {
            let w = AllWire::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: AllWire = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_mod3() {
        let mut rng = thread_rng();
        let w = WireMod3::rand(&mut rng, 3);
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: WireMod3 = serde_json::from_str(&serialized).unwrap();

        assert_eq!(w, deserialized);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod3() {
        let mut rng = thread_rng();
        let mut w = WireMod3::rand(&mut rng, 3);

        // lsb and msb can't both be set
        w.lsb |= 1;
        w.msb |= 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireMod3, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_mod_q() {
        let mut rng = thread_rng();

        for _ in 0..16 {
            let mut q: u16 = rng.r#gen();
            while q < 2 {
                q = rng.r#gen();
            }
            let w = WireModQ::rand(&mut rng, q);
            let serialized = serde_json::to_string(&w).unwrap();

            let deserialized: WireModQ = serde_json::from_str(&serialized).unwrap();

            assert_eq!(w, deserialized);
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_mod() {
        let mut rng = thread_rng();
        let mut q: u16 = rng.r#gen();
        while q < 2 {
            q = rng.r#gen();
        }

        let mut w = WireModQ::rand(&mut rng, q);

        // Manually mess with the modulus
        w.q = 1;
        let serialized = serde_json::to_string(&w).unwrap();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_mod() {
        let serialized: String = "{\"q\":2,\"ds\":[1,1,0,1,0,5,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}".to_string();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod_q_ds_count() {
        let serialized: String = "{\"q\":2,\"ds\":[1,1,0,1,0,1,0,0,0,1,1,1,0,0,1,1,0,1,1,1,0,0,0,1,1,0,0,1,1,0,0,0,1,0,1,1,0,1,1,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,1,0,0,1,1,1,1,1,0,1,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1,1,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0]}".to_string();

        let deserialized: Result<WireModQ, _> = serde_json::from_str(&serialized);
        assert!(deserialized.is_err());
    }
}