fancy_garbling/circuits/arithmetic/
pmr.rs1use crate::{
2 CrtBundle, FancyArithmetic, FancyBinary, FancyProj, HasModulus,
3 circuit::Circuit,
4 circuits::arithmetic::{ModChange, Subtraction},
5 util::inv,
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
/// Circuit that converts a CRT bundle into its primitive mixed-radix (PMR) digits.
///
/// For moduli `m_0, m_1, ..., m_{n-1}`, output digit `d_i` is a wire modulo `m_i`, and
/// the encoded value is `d_0 + d_1 * m_0 + d_2 * m_0 * m_1 + ...`.
#[derive(Default)]
pub struct ToPmr<'a>(PhantomData<&'a ()>);

impl<'a> ToPmr<'a> {
    /// Creates the circuit; it carries no state.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
21
22impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for ToPmr<'a>
23where
24 F::Item: 'a,
25{
26 type Input = &'a CrtBundle<F::Item>;
27 type Output = CrtBundle<F::Item>;
28
29 fn execute(
30 &self,
31 backend: &mut F,
32 inputs: Self::Input,
33 channel: &mut Channel,
34 ) -> Result<Self::Output> {
35 let xs = inputs;
36 let gadget_projection_tt = |p: u16, q: u16| -> Vec<u16> {
37 let pq = p as u32 + q as u32 - 1;
38 let mut tab = Vec::with_capacity(pq as usize);
39 for z in 0..pq {
40 let mut x = 0;
41 let mut y = 0;
42 'outer: for i in 0..p as u32 {
43 for j in 0..q as u32 {
44 if (i + pq - j) % pq == z {
45 x = i;
46 y = j;
47 break 'outer;
48 }
49 }
50 }
51 debug_assert_eq!((x + pq - y) % pq, z);
52 tab.push(
53 (((x * q as u32 * inv(q as i128, p as i128) as u32
54 + y * p as u32 * inv(p as i128, q as i128) as u32)
55 / p as u32)
56 % q as u32) as u16,
57 );
58 }
59 tab
60 };
61
62 let mut gadget = |x: F::Item, y: F::Item| -> Result<F::Item> {
63 let p = x.modulus();
64 let q = y.modulus();
65 let x_ = ModChange.execute(backend, (x, p + q - 1), channel)?;
66 let y_ = ModChange.execute(backend, (y, p + q - 1), channel)?;
67 let z = backend.sub(&x_, &y_);
68 backend.proj(&z, q, Some(gadget_projection_tt(p, q)), channel)
69 };
70
71 let n = xs.size();
72 let mut x = vec![vec![None; n + 1]; n + 1];
73
74 for j in 0..n {
75 x[0][j + 1] = Some(xs.wires()[j].clone());
76 }
77
78 for i in 1..=n {
79 for j in i + 1..=n {
80 let z = gadget(x[i - 1][i].clone().unwrap(), x[i - 1][j].clone().unwrap())?;
81 x[i][j] = Some(z);
82 }
83 }
84
85 let mut zwires = Vec::with_capacity(n);
86 for i in 0..n {
87 zwires.push(x[i][i + 1].take().unwrap());
88 }
89 Ok(CrtBundle::new(zwires))
90 }
91}
92
/// Circuit computing `x < y` for two CRT bundles via primitive mixed-radix conversion.
///
/// Produces a single wire modulo 2. The result is only meaningful when both inputs are
/// below the product of every modulus except the last one in the bundle.
#[derive(Default)]
pub struct PmrLessThan<'a>(PhantomData<&'a ()>);

impl<'a> PmrLessThan<'a> {
    /// Creates the circuit; it carries no state.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
108
109impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for PmrLessThan<'a>
110where
111 F::Item: 'a,
112{
113 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
114 type Output = F::Item;
115
116 fn execute(
117 &self,
118 backend: &mut F,
119 inputs: Self::Input,
120 channel: &mut Channel,
121 ) -> Result<Self::Output> {
122 let z = Subtraction::new().execute(backend, inputs, channel)?;
123 let mut pmr = ToPmr::new().execute(backend, &z, channel)?;
124 let w = pmr.pop().unwrap();
125 let mut tab = vec![1; w.modulus() as usize];
126 tab[0] = 0;
127 backend.proj(&w, 2, Some(tab), channel)
128 }
129}
130
/// Circuit computing `x >= y` for two CRT bundles via primitive mixed-radix conversion.
///
/// This is the bitwise negation of the `x < y` comparison. It carries the same input
/// range restriction: both inputs must be below the product of every modulus except the
/// last one in the bundle.
#[derive(Default)]
pub struct PmrGreaterThanOrEqual<'a>(PhantomData<&'a ()>);

impl<'a> PmrGreaterThanOrEqual<'a> {
    /// Creates the circuit; it carries no state.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
146
147impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for PmrGreaterThanOrEqual<'a>
148where
149 F::Item: 'a,
150{
151 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>);
152 type Output = F::Item;
153
154 fn execute(
155 &self,
156 backend: &mut F,
157 inputs: Self::Input,
158 channel: &mut Channel,
159 ) -> Result<Self::Output> {
160 let z = PmrLessThan::new().execute(backend, inputs, channel)?;
161 Ok(backend.negate(&z))
162 }
163}
164
#[cfg(test)]
mod test {
    use rand::Rng;

    use crate::{
        circuits::arithmetic::{
            ToPmr,
            pmr::{PmrGreaterThanOrEqual, PmrLessThan},
        },
        dummy::{Dummy, DummyVal},
        util::{RngExt, product},
    };

    /// Plaintext reference: the primitive mixed-radix digits of `x` for moduli `ps`.
    ///
    /// Digit `i` is `(x / (ps[0] * ... * ps[i-1])) % ps[i]`.
    fn to_pmr_pt(x: u128, ps: &[u16]) -> Vec<u16> {
        ps.iter()
            .scan(1u128, |radix, &p| {
                let p = u128::from(p);
                let digit = ((x / *radix) % p) as u16;
                *radix *= p;
                Some(digit)
            })
            .collect()
    }

    #[test]
    fn to_pmr() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let ps = rng.gen_usable_factors();
            let q = product(&ps);

            // A random value plus both ends of the representable range.
            for x in [rng.r#gen::<u128>() % q, 0, q - 1] {
                let x_input = DummyVal::to_crt(x, q);
                let z = Dummy::eval(&ToPmr::new(), &x_input).unwrap();
                let output = z.wires().iter().map(|w| w.val()).collect::<Vec<_>>();
                assert_eq!(output, to_pmr_pt(x, &ps), "x = {x}, q = {q}");
            }
        }
    }

    #[test]
    fn pmr_less_than() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = product(&qs);
            // The comparison is only defined for inputs below the product of all but
            // the last modulus.
            let q_ = product(&qs[..n - 1]);
            let x = rng.r#gen::<u128>() % q_;
            let y = rng.r#gen::<u128>() % q_;

            for (a, b) in [(x, y), (x, x), (0, q_ - 1), (q_ - 1, 0)] {
                let a_input = DummyVal::to_crt(a, q);
                let b_input = DummyVal::to_crt(b, q);
                let output = Dummy::eval(&PmrLessThan::new(), (&a_input, &b_input)).unwrap();
                assert_eq!(output.val(), (a < b) as u16, "a = {a}, b = {b}, q = {q}");
            }
        }
    }

    #[test]
    fn pmr_greater_than_or_equal() {
        let mut rng = rand::thread_rng();
        for _ in 0..8 {
            let qs = rng.gen_usable_factors();
            let n = qs.len();
            let q = product(&qs);
            // The comparison is only defined for inputs below the product of all but
            // the last modulus.
            let q_ = product(&qs[..n - 1]);
            let x = rng.r#gen::<u128>() % q_;
            let y = rng.r#gen::<u128>() % q_;

            for (a, b) in [(x, y), (x, x), (0, q_ - 1), (q_ - 1, 0)] {
                let a_input = DummyVal::to_crt(a, q);
                let b_input = DummyVal::to_crt(b, q);
                let output =
                    Dummy::eval(&PmrGreaterThanOrEqual::new(), (&a_input, &b_input)).unwrap();
                assert_eq!(output.val(), (a >= b) as u16, "a = {a}, b = {b}, q = {q}");
            }
        }
    }
}