1use crate::{
2 CrtBundle, FancyArithmetic, FancyBinary, FancyProj,
3 circuit::Circuit,
4 circuits::arithmetic::{FractionalMixedRadix, Subtraction},
5 util::get_ms,
6};
7use core::marker::PhantomData;
8use swanky_channel::Channel;
9use swanky_error::Result;
10
11#[derive(Default)]
13pub struct Sign<'a>(PhantomData<&'a ()>);
14
15impl<'a> Sign<'a> {
16 pub fn new() -> Self {
18 Default::default()
19 }
20}
21
22impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for Sign<'a>
23where
24 F::Item: 'a,
25{
26 type Input = (&'a CrtBundle<F::Item>, &'a str);
27 type Output = F::Item;
28
29 fn execute(
30 &self,
31 backend: &mut F,
32 inputs: Self::Input,
33 channel: &mut Channel,
34 ) -> Result<Self::Output> {
35 let (x, accuracy) = inputs;
36 let factors_of_m = get_ms(x, accuracy);
37 let res = FractionalMixedRadix::new().execute(backend, (x, &factors_of_m), channel)?;
38 let p = factors_of_m.last().unwrap();
39 let tt = (0..*p).map(|x| (x >= p / 2) as u16).collect::<Vec<_>>();
40 backend.proj(&res, 2, Some(tt), channel)
41 }
42}
43
44#[derive(Default)]
46pub struct LessThan<'a>(PhantomData<&'a ()>);
47
48impl<'a> LessThan<'a> {
49 pub fn new() -> Self {
51 Default::default()
52 }
53}
54
55impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for LessThan<'a>
56where
57 F::Item: 'a,
58{
59 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>, &'a str);
60 type Output = F::Item;
61
62 fn execute(
63 &self,
64 backend: &mut F,
65 inputs: Self::Input,
66 channel: &mut Channel,
67 ) -> Result<Self::Output> {
68 let (x, y, accuracy) = inputs;
69 let z = Subtraction::new().execute(backend, (x, y), channel)?;
70 Sign::new().execute(backend, (&z, accuracy), channel)
71 }
72}
73
74#[derive(Default)]
76pub struct GreaterThanOrEqual<'a>(PhantomData<&'a ()>);
77
78impl<'a> GreaterThanOrEqual<'a> {
79 pub fn new() -> Self {
81 Default::default()
82 }
83}
84
85impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for GreaterThanOrEqual<'a>
86where
87 F::Item: 'a,
88{
89 type Input = (&'a CrtBundle<F::Item>, &'a CrtBundle<F::Item>, &'a str);
90 type Output = F::Item;
91
92 fn execute(
93 &self,
94 backend: &mut F,
95 inputs: Self::Input,
96 channel: &mut Channel,
97 ) -> Result<Self::Output> {
98 let z = LessThan::new().execute(backend, inputs, channel)?;
99 Ok(backend.negate(&z))
100 }
101}
102
103#[derive(Default)]
105pub struct Max<'a>(PhantomData<&'a ()>);
106
107impl<'a> Max<'a> {
108 pub fn new() -> Self {
110 Default::default()
111 }
112}
113
114impl<'a, F: FancyBinary + FancyArithmetic + FancyProj> Circuit<F> for Max<'a>
115where
116 F::Item: 'a,
117{
118 type Input = (&'a [CrtBundle<F::Item>], &'a str);
119 type Output = CrtBundle<F::Item>;
120
121 fn execute(
122 &self,
123 backend: &mut F,
124 inputs: Self::Input,
125 channel: &mut Channel,
126 ) -> Result<Self::Output> {
127 let (xs, accuracy) = inputs;
128 assert!(!xs.is_empty(), "`xs` cannot be empty");
129
130 xs.iter().skip(1).try_fold(xs[0].clone(), |x, y| {
131 let pos = LessThan::new().execute(backend, (&x, y, accuracy), channel)?;
132 let neg = backend.negate(&pos);
133 Ok(CrtBundle::new(
134 x.wires()
135 .iter()
136 .zip(y.wires().iter())
137 .map(|(x, y)| {
138 let xp = backend.mul(x, &neg, channel)?;
139 let yp = backend.mul(y, &pos, channel)?;
140 Ok(backend.add(&xp, &yp))
141 })
142 .collect::<Result<Vec<_>>>()?,
143 ))
144 })
145 }
146}
147
148#[derive(Default)]
153pub struct Sgn<'a>(PhantomData<&'a ()>);
154
155impl<'a> Sgn<'a> {
156 pub fn new() -> Self {
158 Default::default()
159 }
160}
161
162impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for Sgn<'a>
163where
164 F::Item: 'a,
165{
166 type Input = (&'a CrtBundle<F::Item>, &'a str, Option<&'a [u16]>);
167 type Output = CrtBundle<F::Item>;
168
169 fn execute(
170 &self,
171 backend: &mut F,
172 inputs: Self::Input,
173 channel: &mut Channel,
174 ) -> Result<Self::Output> {
175 let (x, accuracy, output_moduli) = inputs;
176 let sign = Sign::new().execute(backend, (x, accuracy), channel)?;
177 output_moduli
178 .map(|m| m.to_vec())
179 .unwrap_or_else(|| x.moduli())
180 .iter()
181 .map(|&p| {
182 let tt = vec![1, p - 1];
183 backend.proj(&sign, p, Some(tt), channel)
184 })
185 .collect::<Result<_>>()
186 .map(CrtBundle::new)
187 }
188}
189
190#[derive(Default)]
192pub struct ReLU<'a>(PhantomData<&'a ()>);
193
194impl<'a> ReLU<'a> {
195 pub fn new() -> Self {
197 Default::default()
198 }
199}
200
201impl<'a, F: FancyArithmetic + FancyProj> Circuit<F> for ReLU<'a>
202where
203 F::Item: 'a,
204{
205 type Input = (&'a CrtBundle<F::Item>, &'a str, Option<&'a [u16]>);
206 type Output = CrtBundle<F::Item>;
207
208 fn execute(
209 &self,
210 backend: &mut F,
211 inputs: Self::Input,
212 channel: &mut Channel,
213 ) -> Result<Self::Output> {
214 let (x, accuracy, output_moduli) = inputs;
215 let factors_of_m = get_ms(x, accuracy);
216 let res = FractionalMixedRadix::new().execute(backend, (x, &factors_of_m), channel)?;
217
218 let p = *factors_of_m.last().unwrap();
220 let mask_tt = (0..p).map(|x| (x < p / 2) as u16).collect::<Vec<_>>();
221 let mask = backend.proj(&res, 2, Some(mask_tt), channel)?;
222
223 let output_bundle = match output_moduli {
225 Some(ps) => x.with_moduli(ps),
226 None => (*x).clone().extract(),
227 };
228
229 output_bundle
230 .wires()
231 .iter()
232 .map(|x| backend.mul(x, &mask, channel))
233 .collect::<Result<_>>()
234 .map(CrtBundle::new)
235 }
236}
237
#[cfg(test)]
mod test {
    use rand::{Rng, thread_rng};

    use crate::{
        circuits::arithmetic::{
            ReLU, Sgn,
            comparison::{GreaterThanOrEqual, LessThan, Max, Sign},
        },
        dummy::{Dummy, DummyVal},
        util::modulus_with_width,
    };

    /// Accuracy string handed to every circuit under test.
    const ACCURACY: &str = "100%";

    /// Number of random trials in each single-value / pairwise test.
    const TRIALS: usize = 64;

    #[test]
    fn sign() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        // Expected: 1 for values in the upper ("negative") half of the modulus.
        let check = |x: u128| {
            let x_input = DummyVal::to_crt(x, q);
            let output = Dummy::eval(&Sign::new(), (&x_input, ACCURACY)).unwrap();
            assert_eq!(output.val(), u16::from(x >= q / 2));
        };

        check(0);
        for _ in 0..TRIALS {
            check(rng.r#gen::<u128>() % q);
        }
    }

    #[test]
    fn less_than() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        let check = |x: u128, y: u128| {
            let x_input = DummyVal::to_crt(x, q);
            let y_input = DummyVal::to_crt(y, q);
            let output = Dummy::eval(&LessThan::new(), (&x_input, &y_input, ACCURACY)).unwrap();
            assert_eq!(output.val(), u16::from(x < y));
        };

        // Equal operands must compare as "not less than".
        let same = rng.r#gen::<u128>() % q / 2;
        check(same, same);
        for _ in 0..TRIALS {
            let x = rng.r#gen::<u128>() % q / 2;
            let y = rng.r#gen::<u128>() % q / 2;
            check(x, y);
        }
    }

    #[test]
    fn greater_than_or_equal() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        let check = |x: u128, y: u128| {
            let x_input = DummyVal::to_crt(x, q);
            let y_input = DummyVal::to_crt(y, q);
            let output =
                Dummy::eval(&GreaterThanOrEqual::new(), (&x_input, &y_input, ACCURACY)).unwrap();
            assert_eq!(output.val(), u16::from(x >= y));
        };

        // Equal operands must compare as "greater than or equal".
        let same = rng.r#gen::<u128>() % q / 2;
        check(same, same);
        for _ in 0..TRIALS {
            let x = rng.r#gen::<u128>() % q / 2;
            let y = rng.r#gen::<u128>() % q / 2;
            check(x, y);
        }
    }

    #[test]
    fn max() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        for _ in 0..16 {
            let values: Vec<u128> = (0..100).map(|_| rng.r#gen::<u128>() % (q / 2)).collect();
            let expected = values.iter().copied().max().unwrap();

            let bundles: Vec<_> = values.iter().map(|&v| DummyVal::to_crt(v, q)).collect();
            let z = Dummy::eval(&Max::new(), (&bundles[..], ACCURACY)).unwrap();
            assert_eq!(DummyVal::from_crt(&z, q), expected);
        }
    }

    #[test]
    fn sgn() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        // Expected: +1 for non-negative values, -1 (i.e. q - 1) otherwise.
        let check = |x: u128| {
            let x_input = DummyVal::to_crt(x, q);
            let z = Dummy::eval(&Sgn::new(), (&x_input, ACCURACY, None)).unwrap();
            let expected = if x >= q / 2 { q - 1 } else { 1 };
            assert_eq!(DummyVal::from_crt(&z, q), expected);
        };

        check(0);
        for _ in 0..TRIALS {
            check(rng.r#gen::<u128>() % q);
        }
    }

    #[test]
    fn relu() {
        let mut rng = thread_rng();
        let q = modulus_with_width(10);

        // Expected: the value itself if non-negative, 0 otherwise.
        let check = |x: u128| {
            let x_input = DummyVal::to_crt(x, q);
            let z = Dummy::eval(&ReLU::new(), (&x_input, ACCURACY, None)).unwrap();
            let expected = if x >= q / 2 { 0 } else { x };
            assert_eq!(DummyVal::from_crt(&z, q), expected);
        };

        check(0);
        for _ in 0..TRIALS {
            check(rng.r#gen::<u128>() % q);
        }
    }
}