Skip to main content

fancy_garbling/circuits/binary/
binary_shift.rs

1use crate::{BinaryBundle, FancyBinary, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
6/// For a [`BinaryBundle`] `x` and an integer `n`, shift `x` left by `n`,
7/// retaining the size of `x`.
8#[derive(Default)]
9pub struct BinaryLeftShift<'a>(PhantomData<&'a ()>);
10
11impl<'a> BinaryLeftShift<'a> {
12    /// Create a new [`BinaryLeftShift`] circuit.
13    pub fn new() -> Self {
14        Default::default()
15    }
16}
17
18impl<'a, F: FancyBinary> Circuit<F> for BinaryLeftShift<'a>
19where
20    F::Item: 'a,
21{
22    type Input = (&'a BinaryBundle<F::Item>, usize);
23    type Output = BinaryBundle<F::Item>;
24
25    fn execute(
26        &self,
27        backend: &mut F,
28        inputs: Self::Input,
29        channel: &mut Channel,
30    ) -> Result<Self::Output> {
31        let (bundle, n) = inputs;
32        let zero = backend.constant(0, 2, channel)?;
33
34        let mut wires = bundle.wires().to_vec();
35        for _ in 0..n {
36            wires.pop();
37            wires.insert(0, zero.clone());
38        }
39        Ok(BinaryBundle::new(wires))
40    }
41}
42
43/// For a [`BinaryBundle`] `x` and an integer `n`, shift `x` left by `n` 0s,
44/// extending the [`BinaryBundle`].
45///
46/// That is, $`x = x_1,...,x_m`$ becomes $`x_1,...,x_m,0_1,...,0_n`$.
47#[derive(Default)]
48pub struct BinaryLeftShiftExtend<'a>(PhantomData<&'a ()>);
49
50impl<'a> BinaryLeftShiftExtend<'a> {
51    /// Create a new [`BinaryLeftShiftExtend`] circuit.
52    pub fn new() -> Self {
53        Default::default()
54    }
55}
56
57impl<'a, F: FancyBinary> Circuit<F> for BinaryLeftShiftExtend<'a>
58where
59    F::Item: 'a,
60{
61    type Input = (&'a BinaryBundle<F::Item>, usize);
62    type Output = BinaryBundle<F::Item>;
63
64    fn execute(
65        &self,
66        backend: &mut F,
67        inputs: Self::Input,
68        channel: &mut Channel,
69    ) -> Result<Self::Output> {
70        let (bundle, n) = inputs;
71        let mut wires = bundle.wires().to_vec();
72        let zero = backend.constant(0, 2, channel)?;
73        for _ in 0..n {
74            wires.insert(0, zero.clone());
75        }
76        Ok(BinaryBundle::new(wires))
77    }
78}
79
80/// For a [`BinaryBundle`] `x`, integer `n`, and pad `c`, shift `x` right by
81/// `n`, retaining the size of `x` and filling space on the left by `c`.
82#[derive(Default)]
83pub struct BinaryRightShift<'a>(PhantomData<&'a ()>);
84
85impl<'a> BinaryRightShift<'a> {
86    /// Create a new [`BinaryRightShift`] circuit.
87    pub fn new() -> Self {
88        Default::default()
89    }
90}
91
92impl<'a, F: FancyBinary> Circuit<F> for BinaryRightShift<'a>
93where
94    F::Item: 'a,
95{
96    type Input = (&'a BinaryBundle<F::Item>, usize, F::Item);
97    type Output = BinaryBundle<F::Item>;
98
99    fn execute(&self, _: &mut F, inputs: Self::Input, _: &mut Channel) -> Result<Self::Output> {
100        let (x, n, pad) = inputs;
101        let mut wires: Vec<_> = Vec::with_capacity(x.wires().len());
102
103        for i in 0..x.wires().len() {
104            let src_idx = i + n;
105            if src_idx >= x.wires().len() {
106                wires.push(pad.clone())
107            } else {
108                wires.push(x.wires()[src_idx].clone())
109            }
110        }
111        Ok(BinaryBundle::new(wires))
112    }
113}
114
115/// Logical right shift.
116#[derive(Default)]
117pub struct BinaryLogicalRightShift<'a>(PhantomData<&'a ()>);
118
119impl<'a> BinaryLogicalRightShift<'a> {
120    /// Create a new [`BinaryLogicalRightShift`] circuit.
121    pub fn new() -> Self {
122        Default::default()
123    }
124}
125
126impl<'a, F: FancyBinary> Circuit<F> for BinaryLogicalRightShift<'a>
127where
128    F::Item: 'a,
129{
130    type Input = (&'a BinaryBundle<F::Item>, usize);
131    type Output = BinaryBundle<F::Item>;
132
133    fn execute(
134        &self,
135        backend: &mut F,
136        inputs: Self::Input,
137        channel: &mut Channel,
138    ) -> Result<Self::Output> {
139        let (x, n) = inputs;
140        let zero = backend.constant(0, 2, channel)?;
141        BinaryRightShift::new().execute(backend, (x, n, zero), channel)
142    }
143}
144
145/// Arithmetic right shift.
146#[derive(Default)]
147pub struct BinaryArithmeticRightShift<'a>(PhantomData<&'a ()>);
148
149impl<'a> BinaryArithmeticRightShift<'a> {
150    /// Create a new [`BinaryArithmeticRightShift`] circuit.
151    pub fn new() -> Self {
152        Default::default()
153    }
154}
155
156impl<'a, F: FancyBinary> Circuit<F> for BinaryArithmeticRightShift<'a>
157where
158    F::Item: 'a,
159{
160    type Input = (&'a BinaryBundle<F::Item>, usize);
161    type Output = BinaryBundle<F::Item>;
162
163    fn execute(
164        &self,
165        backend: &mut F,
166        inputs: Self::Input,
167        channel: &mut Channel,
168    ) -> Result<Self::Output> {
169        let (x, n) = inputs;
170        let pad = x.wires().last().unwrap();
171        BinaryRightShift::new().execute(backend, (x, n, pad.clone()), channel)
172    }
173}
174
#[cfg(test)]
mod test {
    use crate::{
        circuits::binary::{
            BinaryArithmeticRightShift, BinaryLeftShift, BinaryLeftShiftExtend,
            binary_shift::BinaryLogicalRightShift,
        },
        dummy::{Dummy, DummyVal},
    };
    use rand::{Rng, thread_rng};

    /// `BinaryLeftShift` matches `u64::wrapping_shl` for in-range shifts.
    #[test]
    fn left_shift() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % N;
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);
            let output = Dummy::eval(&BinaryLeftShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(
                DummyVal::from_binary(&output) as u64,
                x.wrapping_shl(shift_size as u32)
            );
        }
    }

    /// `BinaryLeftShiftExtend` matches a widening `u128` left shift.
    #[test]
    fn left_shift_extend() {
        let mut rng = thread_rng();
        let nbits = 64;
        let q = 1 << nbits;

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % nbits;
            let x = rng.r#gen::<u128>() % q;
            let input = DummyVal::to_binary(x, nbits);
            let output = Dummy::eval(&BinaryLeftShiftExtend::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&output), x << shift_size);
        }
    }

    /// `BinaryLogicalRightShift` matches `u64 >>`.
    #[test]
    fn logical_right_shift() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % N;
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);
            let output =
                Dummy::eval(&BinaryLogicalRightShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&output) as u64, x >> shift_size);
        }
    }

    /// `BinaryArithmeticRightShift` matches `i64 >>`.
    #[test]
    fn arithmetic_right_shift() {
        const N: usize = 64;
        const Q: u128 = 1 << N;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let x = rng.r#gen::<u128>() % Q;
            let shift_size = rng.r#gen::<usize>() % N;
            let x_input = DummyVal::to_binary(x, N);
            let output =
                Dummy::eval(&BinaryArithmeticRightShift::new(), (&x_input, shift_size)).unwrap();
            assert_eq!(
                DummyVal::from_binary(&output) as i64,
                (x as i64) >> shift_size
            );
        }
    }

    /// Shifts by the full width or more keep the width and saturate: left and
    /// logical-right shifts give zero, arithmetic-right gives all sign bits.
    #[test]
    fn shift_by_full_width_or_more() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for _ in 0..8 {
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);
            for shift_size in [N, N + 1, 2 * N] {
                let left = Dummy::eval(&BinaryLeftShift::new(), (&input, shift_size)).unwrap();
                assert_eq!(left.wires().len(), N);
                assert_eq!(DummyVal::from_binary(&left), 0);

                let logical =
                    Dummy::eval(&BinaryLogicalRightShift::new(), (&input, shift_size)).unwrap();
                assert_eq!(logical.wires().len(), N);
                assert_eq!(DummyVal::from_binary(&logical), 0);

                let arith =
                    Dummy::eval(&BinaryArithmeticRightShift::new(), (&input, shift_size))
                        .unwrap();
                assert_eq!(arith.wires().len(), N);
                assert_eq!(
                    DummyVal::from_binary(&arith) as i64,
                    (x as i64) >> (N - 1)
                );
            }
        }
    }
}