1use super::security_warning::warn_proj;
2use crate::{
3 AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, FancyInput, HasModulus, WireMod2,
4 check_binary,
5 fancy::{Fancy, FancyReveal},
6 garble::binary_and::BinaryWireLabel,
7 hash_wires,
8 util::{output_tweak, tweak, tweak2},
9 wire::WireLabel,
10};
11use std::marker::PhantomData;
12use swanky_channel::Channel;
13use swanky_error::ErrorKind;
14use vectoreyes::U8x16;
15
/// Evaluator side of the garbling protocol.
///
/// Holds the per-session state needed to evaluate gates as garbled tables and
/// wire labels are read from a `Channel`.
pub struct Evaluator<Wire> {
    /// Index of the next gate to evaluate; advanced by `current_gate` and used
    /// to derive per-gate tweaks.
    current_gate: usize,
    /// Index of the next output to decode; advanced by `current_output` and
    /// used to derive output tweaks.
    current_output: usize,
    /// Label of the constant-one wire (received in `new`); adding it to a
    /// binary wire negates that wire.
    one: Wire,
    /// NOTE(review): redundant, since `Wire` already appears in `one`.
    _phantom: PhantomData<Wire>,
}
26
27impl<Wire: WireLabel> Evaluator<Wire> {
28 pub fn new(channel: &mut Channel) -> swanky_error::Result<Self> {
30 let one = channel.read::<U8x16>()?;
33 Ok(Evaluator {
34 one: Wire::from_repr(one, 2),
35 current_gate: 0,
36 current_output: 0,
37 _phantom: PhantomData,
38 })
39 }
40
41 fn current_gate(&mut self) -> usize {
43 let current = self.current_gate;
44 self.current_gate += 1;
45 current
46 }
47
48 fn current_output(&mut self) -> usize {
50 let current = self.current_output;
51 self.current_output += 1;
52 current
53 }
54
55 pub fn read_wire(&mut self, modulus: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
57 let block = channel.read()?;
58 Ok(Wire::from_repr(block, modulus))
59 }
60}
61
62impl<W: BinaryWireLabel> FancyBinary for Evaluator<W> {
63 fn negate(&mut self, x: &Self::Item) -> Self::Item {
65 *x + self.one
66 }
67
68 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
69 *x + *y
70 }
71
72 fn and(
73 &mut self,
74 A: &Self::Item,
75 B: &Self::Item,
76 channel: &mut Channel,
77 ) -> swanky_error::Result<Self::Item> {
78 let gate_num = self.current_gate();
79 let gate0 = channel.read()?;
80 let gate1 = channel.read()?;
81 Ok(W::evaluate_and_gate(gate_num, A, B, &gate0, &gate1))
82 }
83}
84
85impl<Wire: BinaryWireLabel> FancyInput for Evaluator<Wire> {
86 type Item = Wire;
87
88 fn encode_many(
89 &mut self,
90 _values: &[u16],
91 _moduli: &[u16],
92 _: &mut Channel,
93 ) -> swanky_error::Result<Vec<Self::Item>> {
94 unimplemented!("Evaluator cannot encode values")
95 }
96
97 fn receive_many(
98 &mut self,
99 moduli: &[u16],
100 channel: &mut Channel,
101 ) -> swanky_error::Result<Vec<Self::Item>> {
102 (0..moduli.len())
103 .map(|_| {
104 let block = channel.read()?;
105 Ok(Wire::from_repr(block, 2))
106 })
107 .collect()
108 }
109}
110
111impl<Wire: WireLabel> FancyReveal for Evaluator<Wire> {
112 fn reveal(&mut self, x: &Wire, channel: &mut Channel) -> swanky_error::Result<u16> {
113 let val = self
114 .output(x, channel)?
115 .expect("Evaluator always outputs Some(u16)");
116 channel.write(&val)?;
117 Ok(val)
118 }
119}
120
121impl FancyBinary for Evaluator<AllWire> {
122 fn negate(&mut self, x: &Self::Item) -> Self::Item {
124 check_binary!(x);
125
126 x.clone() + self.one.clone()
127 }
128
129 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Self::Item {
130 check_binary!(x);
131 check_binary!(y);
132
133 self.add(x, y)
134 }
135
136 fn and(
137 &mut self,
138 x: &Self::Item,
139 y: &Self::Item,
140 channel: &mut Channel,
141 ) -> swanky_error::Result<Self::Item> {
142 if let (AllWire::Mod2(A), AllWire::Mod2(B)) = (x, y) {
143 let gate_num = self.current_gate();
144 let gate0 = channel.read()?;
145 let gate1 = channel.read()?;
146 return Ok(AllWire::Mod2(WireMod2::evaluate_and_gate(
147 gate_num, A, B, &gate0, &gate1,
148 )));
149 }
150
151 check_binary!(x);
153 check_binary!(y);
154
155 unreachable!()
157 }
158}
159
160impl<Wire: WireLabel + ArithmeticWire> FancyArithmetic for Evaluator<Wire> {
161 fn add(&mut self, x: &Wire, y: &Wire) -> Wire {
162 assert_eq!(x.modulus(), y.modulus());
163 x.clone() + y.clone()
164 }
165
166 fn sub(&mut self, x: &Wire, y: &Wire) -> Wire {
167 assert_eq!(x.modulus(), y.modulus());
168 x.clone() - y.clone()
169 }
170
171 fn cmul(&mut self, x: &Wire, c: u16) -> Wire {
172 x.clone() * c
173 }
174
175 fn mul(&mut self, A: &Wire, B: &Wire, channel: &mut Channel) -> swanky_error::Result<Wire> {
176 if A.modulus() < B.modulus() {
177 return self.mul(B, A, channel);
178 }
179 let q = A.modulus();
180 let qb = B.modulus();
181 let unequal = q != qb;
182 let ngates = q as usize + qb as usize - 2 + unequal as usize;
183 let mut gate = Vec::with_capacity(ngates);
184 {
185 for _ in 0..ngates {
186 let block = channel.read::<U8x16>()?;
187 gate.push(block);
188 }
189 }
190 let gate_num = self.current_gate();
191 let g = tweak2(gate_num as u64, 0);
192
193 let [hashA, hashB] = hash_wires([A, B], g);
194
195 let L = if A.color() == 0 {
197 Wire::hash_to_mod(hashA, q)
198 } else {
199 let ct_left = gate[A.color() as usize - 1];
200 Wire::from_repr(ct_left ^ hashA, q)
201 };
202
203 let R = if B.color() == 0 {
205 Wire::hash_to_mod(hashB, q)
206 } else {
207 let ct_right = gate[(q + B.color()) as usize - 2];
208 Wire::from_repr(ct_right ^ hashB, q)
209 };
210
211 let new_b_color = if unequal {
214 let minitable = *gate.last().unwrap();
215 let ct = u128::from(minitable) >> (B.color() * 16);
216 let pt = u128::from(B.hash(tweak2(gate_num as u64, 1))) ^ ct;
217 pt as u16
218 } else {
219 B.color()
220 };
221
222 let res = L + R + A.clone() * new_b_color;
223 Ok(res)
224 }
225
226 fn proj(
227 &mut self,
228 x: &Wire,
229 q: u16,
230 _: Option<Vec<u16>>,
231 channel: &mut Channel,
232 ) -> swanky_error::Result<Wire> {
233 warn_proj();
234 let ngates = (x.modulus() - 1) as usize;
235 let mut gate = Vec::with_capacity(ngates);
236 for _ in 0..ngates {
237 let block = channel.read::<U8x16>()?;
238 gate.push(block);
239 }
240 let t = tweak(self.current_gate());
241 if x.color() == 0 {
242 Ok(x.hashback(t, q))
243 } else {
244 let ct = gate[x.color() as usize - 1];
245 Ok(Wire::from_repr(ct ^ x.hash(t), q))
246 }
247 }
248}
249
250impl<Wire: WireLabel> Fancy for Evaluator<Wire> {
251 type Item = Wire;
252
253 fn constant(&mut self, _: u16, q: u16, channel: &mut Channel) -> swanky_error::Result<Wire> {
254 self.read_wire(q, channel)
255 }
256
257 fn output(&mut self, x: &Wire, channel: &mut Channel) -> swanky_error::Result<Option<u16>> {
258 let q = x.modulus();
259 let i = self.current_output();
260
261 let mut ct = Vec::with_capacity(q as usize);
263 for _ in 0..q {
264 let block = channel.read()?;
265 ct.push(block);
266 }
267
268 let mut decoded = None;
270 for k in 0..q {
271 let hashed_wire = x.hash(output_tweak(i, k));
272 if hashed_wire == ct[k as usize] {
273 decoded = Some(k);
274 break;
275 }
276 }
277
278 if let Some(output) = decoded {
279 Ok(Some(output))
280 } else {
281 swanky_error::bail!(ErrorKind::OtherError, "Decoding failed");
282 }
283 }
284}