1use super::*;
2use crate::util;
3use itertools::Itertools;
4
5pub trait FancyInput {
7 type Item: Clone + HasModulus;
9
10 fn encode_many(
18 &mut self,
19 values: &[u16],
20 moduli: &[u16],
21 channel: &mut Channel,
22 ) -> swanky_error::Result<Vec<Self::Item>>;
23
24 fn receive_many(
26 &mut self,
27 moduli: &[u16],
28 channel: &mut Channel,
29 ) -> swanky_error::Result<Vec<Self::Item>>;
30
31 fn encode(
39 &mut self,
40 value: u16,
41 modulus: u16,
42 channel: &mut Channel,
43 ) -> swanky_error::Result<Self::Item> {
44 let mut xs = self.encode_many(&[value], &[modulus], channel)?;
45 Ok(xs.remove(0))
46 }
47
48 fn receive(&mut self, modulus: u16, channel: &mut Channel) -> swanky_error::Result<Self::Item> {
50 let mut xs = self.receive_many(&[modulus], channel)?;
51 Ok(xs.remove(0))
52 }
53
54 fn encode_bundle(
56 &mut self,
57 values: &[u16],
58 moduli: &[u16],
59 channel: &mut Channel,
60 ) -> swanky_error::Result<Bundle<Self::Item>> {
61 self.encode_many(values, moduli, channel).map(Bundle::new)
62 }
63
64 fn receive_bundle(
66 &mut self,
67 moduli: &[u16],
68 channel: &mut Channel,
69 ) -> swanky_error::Result<Bundle<Self::Item>> {
70 self.receive_many(moduli, channel).map(Bundle::new)
71 }
72
73 fn encode_bundles(
78 &mut self,
79 values: &[Vec<u16>],
80 moduli: &[Vec<u16>],
81 channel: &mut Channel,
82 ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
83 let qs = moduli.iter().flatten().cloned().collect_vec();
84 let xs = values.iter().flatten().cloned().collect_vec();
85 assert_eq!(xs.len(), qs.len(), "unequal number of values and moduli");
86 let mut wires = self.encode_many(&xs, &qs, channel)?;
87 let buns = moduli
88 .iter()
89 .map(|qs| {
90 let ws = wires.drain(0..qs.len()).collect_vec();
91 Bundle::new(ws)
92 })
93 .collect_vec();
94 Ok(buns)
95 }
96
97 fn receive_many_bundles(
99 &mut self,
100 moduli: &[Vec<u16>],
101 channel: &mut Channel,
102 ) -> swanky_error::Result<Vec<Bundle<Self::Item>>> {
103 let qs = moduli.iter().flatten().cloned().collect_vec();
104 let mut wires = self.receive_many(&qs, channel)?;
105 let buns = moduli
106 .iter()
107 .map(|qs| {
108 let ws = wires.drain(0..qs.len()).collect_vec();
109 Bundle::new(ws)
110 })
111 .collect_vec();
112 Ok(buns)
113 }
114
115 fn crt_encode(
117 &mut self,
118 value: u128,
119 modulus: u128,
120 channel: &mut Channel,
121 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
122 let qs = util::factor(modulus);
123 let xs = util::crt(value, &qs);
124 self.encode_bundle(&xs, &qs, channel).map(CrtBundle::from)
125 }
126
127 fn crt_receive(
129 &mut self,
130 modulus: u128,
131 channel: &mut Channel,
132 ) -> swanky_error::Result<CrtBundle<Self::Item>> {
133 let qs = util::factor(modulus);
134 self.receive_bundle(&qs, channel).map(CrtBundle::from)
135 }
136
137 fn crt_encode_many(
139 &mut self,
140 values: &[u128],
141 modulus: u128,
142 channel: &mut Channel,
143 ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
144 let mods = util::factor(modulus);
145 let nmods = mods.len();
146 let xs = values
147 .iter()
148 .flat_map(|x| util::crt(*x, &mods))
149 .collect_vec();
150 let qs = itertools::repeat_n(mods, values.len())
151 .flatten()
152 .collect_vec();
153 let mut wires = self.encode_many(&xs, &qs, channel)?;
154 let buns = (0..values.len())
155 .map(|_| {
156 let ws = wires.drain(0..nmods).collect_vec();
157 CrtBundle::new(ws)
158 })
159 .collect_vec();
160 Ok(buns)
161 }
162
163 fn crt_receive_many(
165 &mut self,
166 n: usize,
167 modulus: u128,
168 channel: &mut Channel,
169 ) -> swanky_error::Result<Vec<CrtBundle<Self::Item>>> {
170 let mods = util::factor(modulus);
171 let nmods = mods.len();
172 let qs = itertools::repeat_n(mods, n).flatten().collect_vec();
173 let mut wires = self.receive_many(&qs, channel)?;
174 let buns = (0..n)
175 .map(|_| {
176 let ws = wires.drain(0..nmods).collect_vec();
177 CrtBundle::new(ws)
178 })
179 .collect_vec();
180 Ok(buns)
181 }
182
183 fn bin_encode(
185 &mut self,
186 value: u128,
187 nbits: usize,
188 channel: &mut Channel,
189 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
190 let xs = util::u128_to_bits(value, nbits);
191 self.encode_bundle(&xs, &vec![2; nbits], channel)
192 .map(BinaryBundle::from)
193 }
194
195 fn bin_receive(
197 &mut self,
198 nbits: usize,
199 channel: &mut Channel,
200 ) -> swanky_error::Result<BinaryBundle<Self::Item>> {
201 self.receive_bundle(&vec![2; nbits], channel)
202 .map(BinaryBundle::from)
203 }
204
205 fn bin_encode_many(
207 &mut self,
208 values: &[u128],
209 nbits: usize,
210 channel: &mut Channel,
211 ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
212 let xs = values
213 .iter()
214 .flat_map(|x| util::u128_to_bits(*x, nbits))
215 .collect_vec();
216 let mut wires = self.encode_many(&xs, &vec![2; values.len() * nbits], channel)?;
217 let buns = (0..values.len())
218 .map(|_| {
219 let ws = wires.drain(0..nbits).collect_vec();
220 BinaryBundle::new(ws)
221 })
222 .collect_vec();
223 Ok(buns)
224 }
225
226 fn bin_receive_many(
228 &mut self,
229 ninputs: usize,
230 nbits: usize,
231 channel: &mut Channel,
232 ) -> swanky_error::Result<Vec<BinaryBundle<Self::Item>>> {
233 let mut wires = self.receive_many(&vec![2; ninputs * nbits], channel)?;
234 let buns = (0..ninputs)
235 .map(|_| {
236 let ws = wires.drain(0..nbits).collect_vec();
237 BinaryBundle::new(ws)
238 })
239 .collect_vec();
240 Ok(buns)
241 }
242}