fancy_garbling/circuits/binary/
binary_subtraction.rs1use crate::{
2 BinaryBundle, FancyBinary,
3 circuit::Circuit,
4 circuits::binary::{BinaryAddition, BinaryTwosComplement},
5};
6use core::marker::PhantomData;
7use swanky_channel::Channel;
8use swanky_error::Result;
9
/// Circuit computing the difference `x - y` of two binary bundles.
///
/// Subtraction is implemented as `x + twos_complement(y)`, i.e. by composing
/// [`BinaryTwosComplement`] and [`BinaryAddition`]. The struct carries no
/// data; the lifetime `'a` only ties the circuit to the borrowed input
/// bundles it is executed on.
#[derive(Debug, Clone, Copy, Default)]
pub struct BinarySubtraction<'a>(PhantomData<&'a ()>);

impl<'a> BinarySubtraction<'a> {
    /// Creates a new subtraction circuit.
    ///
    /// Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
23
24impl<'a, F: FancyBinary> Circuit<F> for BinarySubtraction<'a>
25where
26 F::Item: 'a,
27{
28 type Input = (&'a BinaryBundle<F::Item>, &'a BinaryBundle<F::Item>);
29 type Output = (BinaryBundle<F::Item>, F::Item);
30
31 fn execute(
32 &self,
33 backend: &mut F,
34 inputs: Self::Input,
35 channel: &mut Channel,
36 ) -> Result<Self::Output> {
37 let (x, y) = inputs;
38 assert_eq!(x.moduli(), y.moduli());
39 let neg_y = BinaryTwosComplement::new().execute(backend, y, channel)?;
40 BinaryAddition::new().execute(backend, (x, &neg_y), channel)
41 }
42}
43
44pub mod test {
45 use super::*;
46 use crate::circuit::CircuitInputMapper;
47
48 pub struct TestBinarySubtraction(pub usize);
50 impl<F: FancyBinary> Circuit<F> for TestBinarySubtraction {
51 type Input = (BinaryBundle<F::Item>, BinaryBundle<F::Item>);
52 type Output = (BinaryBundle<F::Item>, F::Item);
53
54 fn execute(
55 &self,
56 backend: &mut F,
57 inputs: Self::Input,
58 channel: &mut Channel,
59 ) -> Result<Self::Output> {
60 BinarySubtraction::new().execute(backend, (&inputs.0, &inputs.1), channel)
61 }
62 }
63
64 impl<F: FancyBinary> CircuitInputMapper<F> for TestBinarySubtraction {
65 fn map(&self, inputs: Vec<<F as crate::Fancy>::Item>) -> Self::Input {
66 assert_eq!(inputs.len(), self.0 * 2);
67 let (x, y) = inputs.split_at(self.0);
68 (BinaryBundle::new(x.to_vec()), BinaryBundle::new(y.to_vec()))
69 }
70
71 fn ninputs(&self) -> usize {
72 self.0 * 2
73 }
74
75 fn modulus(&self, _: usize) -> u16 {
76 2
77 }
78 }
79
80 #[test]
81 fn binary_subtraction() {
82 use crate::dummy::{Dummy, DummyVal};
83 use rand::Rng;
84
85 let mut rng = rand::thread_rng();
86 let nbits = 64;
87 let q = 1 << nbits;
88 let c = TestBinarySubtraction(nbits);
89
90 for _ in 0..16 {
91 let x = rng.r#gen::<u128>() % q;
92 let y = rng.r#gen::<u128>() % q;
93 let x_input = DummyVal::to_binary(x, nbits);
94 let y_input = DummyVal::to_binary(y, nbits);
95 let outputs = Dummy::eval(&c, (x_input, y_input)).unwrap();
96 assert_eq!(
97 DummyVal::from_binary(&outputs.0),
98 x.overflowing_sub(y).0 % q
99 );
100 assert_eq!(outputs.1.val(), (y != 0 && x >= y) as u16);
101 }
102 }
103}