vectoreyes/
utils.rs

1use crate::{SimdBase8, U8x16, U32x4, U64x2};
2
3// For compatibility with scuttlebutt::Block
4impl From<u128> for U8x16 {
5    fn from(value: u128) -> Self {
6        bytemuck::cast(value)
7    }
8}
9
10// For compatibility with scuttlebutt::Block
11impl From<U8x16> for u128 {
12    fn from(value: U8x16) -> Self {
13        bytemuck::cast(value)
14    }
15}
16
17impl U8x16 {
18    /// Perform a (full) 128-bit wide carryless multiply
19    ///
20    /// The result of the 128-bit wide carryless multiply is 256-bits. This is returned as
21    /// two 128-bit values `[lower_bits, upper_bits]`.
22    ///
23    /// If you'd like a single 256-bit value, it can be constructed like
24    /// ```
25    /// # use vectoreyes::{U8x16, U8x32};
26    /// let a = U8x16::from(3);
27    /// let b = U8x16::from(7);
28    /// let product: [U8x16; 2] = a.carryless_mul_wide(b);
29    /// let product: U8x32 = product.into();
30    /// # let _ = product;
31    /// ```
32    ///
33    /// _(This function doesn't always return a `U8x32`, since it will use `__m128i` for
34    /// computation on x86_64 machines, and it may be slower to always construct a `__m256i`)_
35    #[inline(always)]
36    pub fn carryless_mul_wide(self, b: Self) -> [Self; 2] {
37        #[inline(always)]
38        fn upper_bits_made_lower(a: U64x2) -> U64x2 {
39            U64x2::from(U8x16::from(a).shift_bytes_right::<8>())
40        }
41
42        #[inline(always)]
43        fn lower_bits_made_upper(a: U64x2) -> U64x2 {
44            U64x2::from(U8x16::from(a).shift_bytes_left::<8>())
45        }
46        // See algorithm 2 on page 12 of https://web.archive.org/web/20191130175212/https://www.intel.com/content/dam/www/public/us/en/documents/white-papers/carry-less-multiplication-instruction-in-gcm-mode-paper.pdf
47        let a: U64x2 = bytemuck::cast(self);
48        let b: U64x2 = bytemuck::cast(b);
49        let c = a.carryless_mul::<true, true>(b);
50        let d = a.carryless_mul::<false, false>(b);
51        // CLMUL(lower bits of a ^ upper bits of a, lower bits of b ^ upper bits of b)
52        let e = (a ^ upper_bits_made_lower(a))
53            .carryless_mul::<false, false>(b ^ upper_bits_made_lower(b));
54        let product_upper_half =
55            c ^ upper_bits_made_lower(c) ^ upper_bits_made_lower(d) ^ upper_bits_made_lower(e);
56        let product_lower_half =
57            d ^ lower_bits_made_upper(d) ^ lower_bits_made_upper(c) ^ lower_bits_made_upper(e);
58        [
59            bytemuck::cast(product_lower_half),
60            bytemuck::cast(product_upper_half),
61        ]
62    }
63}
64
#[test]
fn test_carryless_mul_wide() {
    // Random test vectors: (a, b, lower 128 bits of a*b, upper 128 bits of a*b).
    const VECTORS: [(u128, u128, u128, u128); 3] = [
        (
            113718949524325212707291430558820879029,
            305595614614064458589355305592899341783,
            181870553715282462853040151492428488859,
            69303674900886469910632566104075007218,
        ),
        (
            305491409529336450059265117908006794202,
            331330386820708447646441739307072964010,
            127269516908168038593688997658496458020,
            125659689760004568937468201162182112345,
        ),
        (
            267625637845811074182836635736437393132,
            98247896988070748377279692417561622532,
            47973638020603525196354339630722399152,
            69947343163265692377803117866524991745,
        ),
    ];
    for &(a, b, lower, upper) in &VECTORS {
        assert_eq!(
            U8x16::from(a).carryless_mul_wide(U8x16::from(b)),
            [U8x16::from(lower), U8x16::from(upper)]
        );
    }
}
93
94// These won't be used for the scalar backend, hence the allow(unused).
95#[allow(unused)]
96#[derive(Clone)]
97pub(crate) struct AesEncryptOnlyKeySchedule<const NUM_ROUNDS: usize> {
98    pub(crate) encrypt_keys: [U32x4; NUM_ROUNDS],
99}
100#[allow(unused)]
101#[derive(Clone)]
102pub(crate) struct AesKeySchedule<const NUM_ROUNDS: usize> {
103    pub(crate) encrypt_keys: [U32x4; NUM_ROUNDS],
104    pub(crate) decrypt_keys: [U32x4; NUM_ROUNDS],
105}