fancy_garbling/wire/mod3.rs

1#[cfg(feature = "serde")]
2use crate::errors::WireDeserializationError;
3use crate::{ArithmeticWire, HasModulus, WireLabel, wire::_unrank};
4use rand::{CryptoRng, Rng, RngCore};
5use vectoreyes::U8x16;
6
/// Unvalidated form of a [`WireMod3`], used as the serde deserialization target.
///
/// Converting it into a [`WireMod3`] (via `TryFrom`) rejects any input in which
/// some element has both its `lsb` bit and its `msb` bit set.
#[cfg(feature = "serde")]
#[derive(serde::Deserialize)]
struct UntrustedWireMod3 {
    /// Low bit of each of the 64 `mod-3` elements.
    lsb: u64,
    /// High bit of each of the 64 `mod-3` elements.
    msb: u64,
}
18
#[cfg(feature = "serde")]
impl TryFrom<UntrustedWireMod3> for WireMod3 {
    type Error = WireDeserializationError;

    /// Validates the bit-sliced encoding before building a [`WireMod3`].
    ///
    /// No element may have both its `lsb` and `msb` bit set, because that bit
    /// pattern encodes the out-of-range value 3.
    ///
    /// # Errors
    /// Returns `InvalidWireMod3` if `lsb & msb != 0`.
    fn try_from(wire: UntrustedWireMod3) -> Result<Self, Self::Error> {
        let UntrustedWireMod3 { lsb, msb } = wire;
        if lsb & msb == 0 {
            Ok(WireMod3 { lsb, msb })
        } else {
            Err(WireDeserializationError::InvalidWireMod3)
        }
    }
}
33
/// A `mod-3` wire label holding 64 elements of `Z_3` in bit-sliced form.
///
/// Element `i` lives in bit `i` of the two fields. Its least-significant bit
/// is stored in `lsb` and its most-significant bit in `msb`, so 0 is
/// `(lsb, msb) = (0, 0)`, 1 is `(1, 0)` and 2 is `(0, 1)`. A valid wire never
/// has the same bit set in both fields (`lsb & msb == 0`).
///
/// This layout enables the efficient addition and multiplication described in
/// "Hardware Implementation of Finite Fields of Characteristic Three",
/// D. Page and N. P. Smart, CHES 2002:
/// <https://link.springer.com/content/pdf/10.1007/3-540-36400-5_38.pdf>.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "UntrustedWireMod3"))]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct WireMod3 {
    /// Low bit of each of the 64 `mod-3` elements.
    pub(crate) lsb: u64,
    /// High bit of each of the 64 `mod-3` elements.
    pub(crate) msb: u64,
}
53
54impl HasModulus for WireMod3 {
55    fn modulus(&self) -> u16 {
56        3
57    }
58}
59
60impl core::ops::Add for WireMod3 {
61    type Output = Self;
62
63    fn add(self, rhs: Self) -> Self::Output {
64        let a1 = self.lsb;
65        let a2 = self.msb;
66        let b1 = rhs.lsb;
67        let b2 = rhs.msb;
68
69        let t = (a1 | b2) ^ (a2 | b1);
70        let c1 = (a2 | b2) ^ t;
71        let c2 = (a1 | b1) ^ t;
72        Self { lsb: c1, msb: c2 }
73    }
74}
75
76impl core::ops::AddAssign for WireMod3 {
77    fn add_assign(&mut self, rhs: Self) {
78        *self = *self + rhs;
79    }
80}
81
82impl core::ops::Sub for WireMod3 {
83    type Output = Self;
84
85    fn sub(self, rhs: Self) -> Self::Output {
86        self + -rhs
87    }
88}
89
90impl core::ops::SubAssign for WireMod3 {
91    fn sub_assign(&mut self, rhs: Self) {
92        *self = *self - rhs;
93    }
94}
95
96impl core::ops::Neg for WireMod3 {
97    type Output = Self;
98
99    fn neg(self) -> Self::Output {
100        // Negation just involves swapping `lsb` and `msb`.
101        let mut output = self;
102        std::mem::swap(&mut output.lsb, &mut output.msb);
103        output
104    }
105}
106
107impl core::ops::Mul<u16> for WireMod3 {
108    type Output = Self;
109
110    #[allow(clippy::suspicious_arithmetic_impl)]
111    fn mul(self, rhs: u16) -> Self::Output {
112        let c = rhs % 3;
113        match c {
114            0 => Self { msb: 0, lsb: 0 },
115            1 => self,
116            2 => Self {
117                msb: self.lsb,
118                lsb: self.msb,
119            },
120            _ => unreachable!("Due to initial `rhs % 3`"),
121        }
122    }
123}
124
125impl core::ops::MulAssign<u16> for WireMod3 {
126    #[allow(clippy::suspicious_op_assign_impl)]
127    fn mul_assign(&mut self, rhs: u16) {
128        let c = rhs % 3;
129        match c {
130            0 => {
131                self.msb = 0;
132                self.lsb = 0;
133            }
134            1 => {}
135            2 => {
136                std::mem::swap(&mut self.lsb, &mut self.msb);
137            }
138            _ => unreachable!("Due to initial `rhs % 3`"),
139        }
140    }
141}
142
143impl WireMod3 {
144    /// We have to convert `block` into a valid `Mod3` encoding.
145    ///
146    /// We do this by computing the `Mod3` digits using `_unrank`,
147    /// and then map these to a `Mod3` encoding.
148    pub(crate) fn encode_block_mod3(block: U8x16) -> Self {
149        let mut lsb = 0u64;
150        let mut msb = 0u64;
151        let mut ds = _unrank(u128::from(block), 3);
152        for (i, v) in ds.drain(..64).enumerate() {
153            lsb |= ((v & 1) as u64) << i;
154            msb |= (((v >> 1) & 1u16) as u64) << i;
155        }
156        debug_assert_eq!(lsb & msb, 0);
157        Self { lsb, msb }
158    }
159}
160
161impl WireLabel for WireMod3 {
162    fn rand_delta<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
163        if q != 3 {
164            panic!("[WireMod3::rand_delta] Expected modulo 3. Got {}", q);
165        }
166        let mut w = Self::rand(rng, 3);
167        w.lsb |= 1;
168        w.msb &= 0xFFFF_FFFF_FFFF_FFFE;
169        w
170    }
171
172    fn digits(&self) -> Vec<u16> {
173        (0..64)
174            .map(|i| (((self.lsb >> i) as u16) & 1) & ((((self.msb >> i) as u16) & 1) << 1))
175            .collect()
176    }
177
178    fn to_repr(&self) -> U8x16 {
179        // This function converts a [`WireMod3`] into its [`Block`] representation.
180        // The two 64b values stored in [`WireMod3`], i.e. the lsb and msb, and packed
181        // into a 128b value as a [`Block`].
182        (((self.msb as u128) << 64) | (self.lsb as u128)).into()
183    }
184
185    fn color(&self) -> u16 {
186        let color = (((self.msb & 1) as u16) << 1) | ((self.lsb & 1) as u16);
187        debug_assert_ne!(color, 3);
188        color
189    }
190
191    fn from_repr(inp: U8x16, q: u16) -> Self {
192        if q != 3 {
193            panic!("[WireMod3::from_block] Expected mod 3. Got mod {}", q)
194        }
195        // This function converts a Block into its WireLabel representation
196        // by splitting the Block into two u64, its least significant bits and
197        // its most significant bits.
198        let inp = u128::from(inp);
199        let lsb = inp as u64;
200        let msb = (inp >> 64) as u64;
201        debug_assert_eq!(lsb & msb, 0);
202        Self { lsb, msb }
203    }
204
205    fn zero(q: u16) -> Self {
206        if q != 3 {
207            panic!("[WireMod3::zero] Expected modulo 3. Got {}", q);
208        }
209        Self::default()
210    }
211
212    fn rand<R: CryptoRng + RngCore>(rng: &mut R, q: u16) -> Self {
213        if q != 3 {
214            panic!("[WireMod3::rand] Expected mod 3. Got mod {}", q)
215        }
216        let mut lsb = 0u64;
217        let mut msb = 0u64;
218        for (i, v) in (0..64).map(|_| rng.r#gen::<u8>() % 3).enumerate() {
219            lsb |= ((v & 1) as u64) << i;
220            msb |= (((v >> 1) & 1) as u64) << i;
221        }
222        debug_assert_eq!(lsb & msb, 0);
223        Self { lsb, msb }
224    }
225
226    fn hash_to_mod(hash: U8x16, q: u16) -> Self {
227        if q != 3 {
228            panic!("[WireMod3::hash_to_mod] Expected mod 3. Got mod {}", q)
229        }
230        Self::encode_block_mod3(hash)
231    }
232}
233
234impl ArithmeticWire for WireMod3 {}
235
#[cfg(test)]
mod tests {
    #[cfg(feature = "serde")]
    use crate::{WireLabel, WireMod3};
    #[cfg(feature = "serde")]
    use rand::thread_rng;

    /// A freshly sampled (hence valid) wire survives a JSON round trip unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_good_mod3() {
        let original = WireMod3::rand(&mut thread_rng(), 3);
        let json = serde_json::to_string(&original).unwrap();
        let decoded: WireMod3 = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    /// Deserialization rejects a wire whose element 0 has both bits set.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_bad_mod3() {
        let mut corrupted = WireMod3::rand(&mut thread_rng(), 3);
        // lsb and msb can't both be set
        corrupted.lsb |= 1;
        corrupted.msb |= 1;
        let json = serde_json::to_string(&corrupted).unwrap();
        let outcome: Result<WireMod3, _> = serde_json::from_str(&json);
        assert!(outcome.is_err());
    }
}