1use std::marker::PhantomData;
2
3use crate::{
4 AllWire, ArithmeticWire, FancyArithmetic, FancyBinary, HasModulus, WireMod2, check_binary,
5 errors::{EvaluatorError, FancyError},
6 fancy::{Fancy, FancyReveal},
7 hash_wires,
8 util::{output_tweak, tweak, tweak2},
9 wire::WireLabel,
10};
11use subtle::ConditionallySelectable;
12use swanky_block::Block;
13use swanky_channel_legacy::AbstractChannel;
14
15use super::security_warning::warn_proj;
16
/// Evaluator side of a garbled-circuit protocol over wire labels of type `Wire`.
///
/// Gate ciphertexts, constant labels and output-decoding blocks are all read
/// from `channel`; revealed plaintext values are written back to it.
pub struct Evaluator<C, Wire> {
    /// Index handed out to the next gate; feeds the per-gate hash tweak.
    current_gate: usize,
    /// Index handed out to the next output; feeds the per-output hash tweak.
    current_output: usize,
    /// Channel to the other party (presumably the garbler).
    pub(crate) channel: C,
    /// Ties the evaluator to one wire-label type without storing a value of it.
    _phantom: PhantomData<Wire>,
}
27
28impl<C: AbstractChannel, Wire: WireLabel> Evaluator<C, Wire> {
29 pub fn new(channel: C) -> Self {
31 Evaluator {
32 channel,
33 current_gate: 0,
34 current_output: 0,
35 _phantom: PhantomData,
36 }
37 }
38
39 fn current_gate(&mut self) -> usize {
41 let current = self.current_gate;
42 self.current_gate += 1;
43 current
44 }
45
46 fn current_output(&mut self) -> usize {
48 let current = self.current_output;
49 self.current_output += 1;
50 current
51 }
52
53 pub fn read_wire(&mut self, modulus: u16) -> Result<Wire, EvaluatorError> {
55 let block = self.channel.read_block()?;
56 Ok(Wire::from_block(block, modulus))
57 }
58
59 fn evaluate_and_gate(
65 &mut self,
66 A: &WireMod2,
67 B: &WireMod2,
68 gate0: &Block,
69 gate1: &Block,
70 ) -> WireMod2 {
71 let gate_num = self.current_gate();
72 let g = tweak2(gate_num as u64, 0);
73
74 let [hashA, hashB] = hash_wires([A, B], g);
75
76 let L = WireMod2::from_block(
78 Block::conditional_select(&hashA, &(hashA ^ *gate0), (A.color() as u8).into()),
79 2,
80 );
81
82 let R = WireMod2::from_block(
84 Block::conditional_select(&hashB, &(hashB ^ *gate1), (B.color() as u8).into()),
85 2,
86 );
87
88 L.plus_mov(&R.plus_mov(&A.cmul(B.color())))
89 }
90}
91
92impl<C: AbstractChannel> FancyBinary for Evaluator<C, WireMod2> {
93 fn negate(&mut self, x: &Self::Item) -> Result<Self::Item, Self::Error> {
95 Ok(*x)
96 }
97
98 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Result<Self::Item, Self::Error> {
99 Ok(x.plus(y))
100 }
101
102 fn and(&mut self, A: &Self::Item, B: &Self::Item) -> Result<Self::Item, Self::Error> {
103 let gate0 = self.channel.read_block()?;
104 let gate1 = self.channel.read_block()?;
105 Ok(self.evaluate_and_gate(A, B, &gate0, &gate1))
106 }
107}
108
109impl<C: AbstractChannel, Wire: WireLabel> FancyReveal for Evaluator<C, Wire> {
110 fn reveal(&mut self, x: &Wire) -> Result<u16, EvaluatorError> {
111 let val = self.output(x)?.expect("Evaluator always outputs Some(u16)");
112 self.channel.write_u16(val)?;
113 self.channel.flush()?;
114 Ok(val)
115 }
116}
117
118impl<C: AbstractChannel> FancyBinary for Evaluator<C, AllWire> {
119 fn negate(&mut self, x: &Self::Item) -> Result<Self::Item, Self::Error> {
121 check_binary!(x);
122
123 Ok(x.clone())
124 }
125
126 fn xor(&mut self, x: &Self::Item, y: &Self::Item) -> Result<Self::Item, Self::Error> {
127 check_binary!(x);
128 check_binary!(y);
129
130 self.add(x, y)
131 }
132
133 fn and(&mut self, x: &Self::Item, y: &Self::Item) -> Result<Self::Item, Self::Error> {
134 if let (AllWire::Mod2(A), AllWire::Mod2(B)) = (x, y) {
135 let gate0 = self.channel.read_block()?;
136 let gate1 = self.channel.read_block()?;
137 return Ok(AllWire::Mod2(self.evaluate_and_gate(A, B, &gate0, &gate1)));
138 }
139
140 check_binary!(x);
142 check_binary!(y);
143
144 unreachable!()
146 }
147}
148
149impl<C: AbstractChannel, Wire: WireLabel + ArithmeticWire> FancyArithmetic for Evaluator<C, Wire> {
150 fn add(&mut self, x: &Wire, y: &Wire) -> Result<Wire, EvaluatorError> {
151 if x.modulus() != y.modulus() {
152 return Err(EvaluatorError::FancyError(FancyError::UnequalModuli));
153 }
154 Ok(x.plus(y))
155 }
156
157 fn sub(&mut self, x: &Wire, y: &Wire) -> Result<Wire, EvaluatorError> {
158 if x.modulus() != y.modulus() {
159 return Err(EvaluatorError::FancyError(FancyError::UnequalModuli));
160 }
161 Ok(x.minus(y))
162 }
163
164 fn cmul(&mut self, x: &Wire, c: u16) -> Result<Wire, EvaluatorError> {
165 Ok(x.cmul(c))
166 }
167
168 fn mul(&mut self, A: &Wire, B: &Wire) -> Result<Wire, EvaluatorError> {
169 if A.modulus() < B.modulus() {
170 return self.mul(B, A);
171 }
172 let q = A.modulus();
173 let qb = B.modulus();
174 let unequal = q != qb;
175 let ngates = q as usize + qb as usize - 2 + unequal as usize;
176 let mut gate = Vec::with_capacity(ngates);
177 {
178 for _ in 0..ngates {
179 let block = self.channel.read_block()?;
180 gate.push(block);
181 }
182 }
183 let gate_num = self.current_gate();
184 let g = tweak2(gate_num as u64, 0);
185
186 let [hashA, hashB] = hash_wires([A, B], g);
187
188 let L = if A.color() == 0 {
190 Wire::hash_to_mod(hashA, q)
191 } else {
192 let ct_left = gate[A.color() as usize - 1];
193 Wire::from_block(ct_left ^ hashA, q)
194 };
195
196 let R = if B.color() == 0 {
198 Wire::hash_to_mod(hashB, q)
199 } else {
200 let ct_right = gate[(q + B.color()) as usize - 2];
201 Wire::from_block(ct_right ^ hashB, q)
202 };
203
204 let new_b_color = if unequal {
207 let minitable = *gate.last().unwrap();
208 let ct = u128::from(minitable) >> (B.color() * 16);
209 let pt = u128::from(B.hash(tweak2(gate_num as u64, 1))) ^ ct;
210 pt as u16
211 } else {
212 B.color()
213 };
214
215 let res = L.plus_mov(&R.plus_mov(&A.cmul(new_b_color)));
216 Ok(res)
217 }
218
219 fn proj(&mut self, x: &Wire, q: u16, _: Option<Vec<u16>>) -> Result<Wire, EvaluatorError> {
220 warn_proj();
221 let ngates = (x.modulus() - 1) as usize;
222 let mut gate = Vec::with_capacity(ngates);
223 for _ in 0..ngates {
224 let block = self.channel.read_block()?;
225 gate.push(block);
226 }
227 let t = tweak(self.current_gate());
228 if x.color() == 0 {
229 Ok(x.hashback(t, q))
230 } else {
231 let ct = gate[x.color() as usize - 1];
232 Ok(Wire::from_block(ct ^ x.hash(t), q))
233 }
234 }
235}
236
237impl<C: AbstractChannel, Wire: WireLabel> Fancy for Evaluator<C, Wire> {
238 type Item = Wire;
239 type Error = EvaluatorError;
240
241 fn constant(&mut self, _: u16, q: u16) -> Result<Wire, EvaluatorError> {
242 self.read_wire(q)
243 }
244
245 fn output(&mut self, x: &Wire) -> Result<Option<u16>, EvaluatorError> {
246 let q = x.modulus();
247 let i = self.current_output();
248
249 let ct = self.channel.read_blocks(q as usize)?;
251
252 let mut decoded = None;
254 for k in 0..q {
255 let hashed_wire = x.hash(output_tweak(i, k));
256 if hashed_wire == ct[k as usize] {
257 decoded = Some(k);
258 break;
259 }
260 }
261
262 if let Some(output) = decoded {
263 Ok(Some(output))
264 } else {
265 Err(EvaluatorError::DecodingFailed)
266 }
267 }
268}