fancy_garbling/circuits/binary/
binary_shift.rs1use crate::{BinaryBundle, FancyBinary, circuit::Circuit};
2use core::marker::PhantomData;
3use swanky_channel::Channel;
4use swanky_error::Result;
5
/// Circuit that shifts a [`BinaryBundle`] left, i.e. toward its most significant end.
///
/// Wires are ordered least-significant first. Vacated low-order positions are
/// filled with constant-zero wires. Wires shifted past the most significant
/// position are dropped rather than kept (contrast [`BinaryLeftShiftExtend`]).
///
/// The shift amount is supplied at execution time as part of the circuit input.
#[derive(Default)]
pub struct BinaryLeftShift<'a>(PhantomData<&'a ()>);

impl<'a> BinaryLeftShift<'a> {
    /// Creates the left-shift circuit; it carries no state of its own.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
17
18impl<'a, F: FancyBinary> Circuit<F> for BinaryLeftShift<'a>
19where
20 F::Item: 'a,
21{
22 type Input = (&'a BinaryBundle<F::Item>, usize);
23 type Output = BinaryBundle<F::Item>;
24
25 fn execute(
26 &self,
27 backend: &mut F,
28 inputs: Self::Input,
29 channel: &mut Channel,
30 ) -> Result<Self::Output> {
31 let (bundle, n) = inputs;
32 let zero = backend.constant(0, 2, channel)?;
33
34 let mut wires = bundle.wires().to_vec();
35 for _ in 0..n {
36 wires.pop();
37 wires.insert(0, zero.clone());
38 }
39 Ok(BinaryBundle::new(wires))
40 }
41}
42
/// Circuit that shifts a [`BinaryBundle`] left by prepending constant-zero wires.
///
/// Wires are ordered least-significant first. Shifting by `n` therefore inserts
/// `n` zero wires at the low end. The output is `n` wires wider than the input,
/// and no input wire is discarded.
#[derive(Default)]
pub struct BinaryLeftShiftExtend<'a>(PhantomData<&'a ()>);

impl<'a> BinaryLeftShiftExtend<'a> {
    /// Creates the width-extending left-shift circuit; it carries no state of its own.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
56
57impl<'a, F: FancyBinary> Circuit<F> for BinaryLeftShiftExtend<'a>
58where
59 F::Item: 'a,
60{
61 type Input = (&'a BinaryBundle<F::Item>, usize);
62 type Output = BinaryBundle<F::Item>;
63
64 fn execute(
65 &self,
66 backend: &mut F,
67 inputs: Self::Input,
68 channel: &mut Channel,
69 ) -> Result<Self::Output> {
70 let (bundle, n) = inputs;
71 let mut wires = bundle.wires().to_vec();
72 let zero = backend.constant(0, 2, channel)?;
73 for _ in 0..n {
74 wires.insert(0, zero.clone());
75 }
76 Ok(BinaryBundle::new(wires))
77 }
78}
79
/// Circuit that shifts a [`BinaryBundle`] right (toward its least significant end)
/// by a given number of bits, padding with a caller-supplied wire.
///
/// Wires are ordered least-significant first. The output has the same number of
/// wires as the input. Vacated high-order positions are filled with clones of the
/// pad wire passed in the circuit input.
#[derive(Default)]
pub struct BinaryRightShift<'a>(PhantomData<&'a ()>);

impl<'a> BinaryRightShift<'a> {
    /// Creates the padded right-shift circuit; it carries no state of its own.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
91
92impl<'a, F: FancyBinary> Circuit<F> for BinaryRightShift<'a>
93where
94 F::Item: 'a,
95{
96 type Input = (&'a BinaryBundle<F::Item>, usize, F::Item);
97 type Output = BinaryBundle<F::Item>;
98
99 fn execute(&self, _: &mut F, inputs: Self::Input, _: &mut Channel) -> Result<Self::Output> {
100 let (x, n, pad) = inputs;
101 let mut wires: Vec<_> = Vec::with_capacity(x.wires().len());
102
103 for i in 0..x.wires().len() {
104 let src_idx = i + n;
105 if src_idx >= x.wires().len() {
106 wires.push(pad.clone())
107 } else {
108 wires.push(x.wires()[src_idx].clone())
109 }
110 }
111 Ok(BinaryBundle::new(wires))
112 }
113}
114
/// Circuit performing a logical right shift of a [`BinaryBundle`].
///
/// Vacated high-order positions are filled with constant-zero wires, and the
/// bundle width is unchanged.
#[derive(Default)]
pub struct BinaryLogicalRightShift<'a>(PhantomData<&'a ()>);

impl<'a> BinaryLogicalRightShift<'a> {
    /// Creates the logical right-shift circuit; it carries no state of its own.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
125
126impl<'a, F: FancyBinary> Circuit<F> for BinaryLogicalRightShift<'a>
127where
128 F::Item: 'a,
129{
130 type Input = (&'a BinaryBundle<F::Item>, usize);
131 type Output = BinaryBundle<F::Item>;
132
133 fn execute(
134 &self,
135 backend: &mut F,
136 inputs: Self::Input,
137 channel: &mut Channel,
138 ) -> Result<Self::Output> {
139 let (x, n) = inputs;
140 let zero = backend.constant(0, 2, channel)?;
141 BinaryRightShift::new().execute(backend, (x, n, zero), channel)
142 }
143}
144
/// Circuit performing an arithmetic (sign-extending) right shift of a [`BinaryBundle`].
///
/// Vacated high-order positions are filled with copies of the input's most
/// significant wire, i.e. its last wire. The bundle width is unchanged.
#[derive(Default)]
pub struct BinaryArithmeticRightShift<'a>(PhantomData<&'a ()>);

impl<'a> BinaryArithmeticRightShift<'a> {
    /// Creates the arithmetic right-shift circuit; it carries no state of its own.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
155
156impl<'a, F: FancyBinary> Circuit<F> for BinaryArithmeticRightShift<'a>
157where
158 F::Item: 'a,
159{
160 type Input = (&'a BinaryBundle<F::Item>, usize);
161 type Output = BinaryBundle<F::Item>;
162
163 fn execute(
164 &self,
165 backend: &mut F,
166 inputs: Self::Input,
167 channel: &mut Channel,
168 ) -> Result<Self::Output> {
169 let (x, n) = inputs;
170 let pad = x.wires().last().unwrap();
171 BinaryRightShift::new().execute(backend, (x, n, pad.clone()), channel)
172 }
173}
174
#[cfg(test)]
mod test {
    use super::{
        BinaryArithmeticRightShift, BinaryLeftShift, BinaryLeftShiftExtend,
        BinaryLogicalRightShift,
    };
    use crate::dummy::{Dummy, DummyVal};
    use rand::{Rng, thread_rng};

    #[test]
    fn left_shift() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % N;
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);
            let output = Dummy::eval(&BinaryLeftShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(
                DummyVal::from_binary(&output) as u64,
                x.wrapping_shl(shift_size as u32)
            );
        }
    }

    #[test]
    fn left_shift_extend() {
        let mut rng = thread_rng();
        let nbits = 64;
        let q = 1 << nbits;

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % nbits;
            let x = rng.r#gen::<u128>() % q;
            let input = DummyVal::to_binary(x, nbits);
            let output = Dummy::eval(&BinaryLeftShiftExtend::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&output), x << shift_size);
        }
    }

    #[test]
    fn logical_right_shift() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let shift_size = rng.r#gen::<usize>() % N;
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);
            let output =
                Dummy::eval(&BinaryLogicalRightShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&output) as u64, x >> shift_size);
        }
    }

    #[test]
    fn arithmetic_right_shift() {
        const N: usize = 64;
        const Q: u128 = 1 << N;
        let mut rng = thread_rng();

        for _ in 0..16 {
            let x = rng.r#gen::<u128>() % Q;
            let shift_size = rng.r#gen::<usize>() % N;
            let x_input = DummyVal::to_binary(x, N);
            let output =
                Dummy::eval(&BinaryArithmeticRightShift::new(), (&x_input, shift_size)).unwrap();
            assert_eq!(
                DummyVal::from_binary(&output) as i64,
                (x as i64) >> shift_size
            );
        }
    }

    /// Shifting by the full width or more must saturate. Left and logical right
    /// shifts produce zero; arithmetic right shifts replicate the sign bit.
    #[test]
    fn shifts_at_or_beyond_width() {
        const N: usize = 64;
        let mut rng = thread_rng();

        for shift_size in [N, N + 1, 2 * N] {
            let x = rng.r#gen::<u64>();
            let input = DummyVal::to_binary(x as u128, N);

            let left = Dummy::eval(&BinaryLeftShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&left), 0);

            let logical =
                Dummy::eval(&BinaryLogicalRightShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(DummyVal::from_binary(&logical), 0);

            let arithmetic =
                Dummy::eval(&BinaryArithmeticRightShift::new(), (&input, shift_size)).unwrap();
            assert_eq!(
                DummyVal::from_binary(&arithmetic) as i64,
                (x as i64) >> (N - 1)
            );
        }
    }
}