Skip to main content

fancy_garbling/circuits/
linear_oram.rs

1use crate::{
2    BinaryBundle, FancyBinary,
3    circuit::{Circuit, CircuitInputMapper},
4    circuits::binary::{BinaryConstant, BinaryEquality, BinaryMultiplex, PairwiseXor},
5};
6use swanky_channel::Channel;
7use swanky_error::Result;
8
/// Circuit for a linear-scan ORAM lookup.
///
/// Given a vector of [`BinaryBundle`]s (the RAM) and a single [`BinaryBundle`] query,
/// the circuit outputs the RAM *element* whose position equals the query, or an
/// all-zero bundle if no position matches (i.e. the query is `>= ram.len()`).
/// Every [`BinaryBundle`], including the query, contains `N` bits.
///
/// Every lookup compares the query against every RAM position, so the circuit size
/// grows linearly with the number of RAM elements.
#[derive(Debug, Clone)]
pub struct LinearOram<const N: usize> {
    /// Number of RAM elements; used to size and split the flat input wire list.
    size: usize,
}

impl<const N: usize> LinearOram<N> {
    /// Create a new [`LinearOram`] containing `size` elements.
    #[must_use]
    pub const fn new(size: usize) -> Self {
        Self { size }
    }
}
24
25impl<F: FancyBinary, const N: usize> Circuit<F> for LinearOram<N> {
26    type Input = (Vec<BinaryBundle<F::Item>>, BinaryBundle<F::Item>);
27    type Output = BinaryBundle<F::Item>;
28
29    fn execute(
30        &self,
31        backend: &mut F,
32        inputs: Self::Input,
33        channel: &mut Channel,
34    ) -> Result<Self::Output> {
35        let (ram, query) = inputs;
36        let zero_bit = backend.constant(0, 2, channel)?;
37        let one_bit = backend.constant(1, 2, channel)?;
38
39        let zero =
40            BinaryConstant::new_with_constants(0, N, Some(zero_bit.clone()), Some(one_bit.clone()))
41                .execute(backend, (), channel)?;
42
43        // Traverse the RAM one element at a time, and multiplex the result
44        // based on whether the query matches the current index.
45        let mut result = zero.clone();
46        for (i, item) in ram.iter().enumerate() {
47            let index = BinaryConstant::new_with_constants(
48                i as u128,
49                N,
50                Some(zero_bit.clone()),
51                Some(one_bit.clone()),
52            )
53            .execute(backend, (), channel)?;
54            let is_equal = BinaryEquality::new().execute(backend, (&query, &index), channel)?;
55            let mux = BinaryMultiplex::new().execute(backend, (is_equal, &zero, item), channel)?;
56            // Every `mux` but one will be zero, so we can use `PairwiseXor`
57            // instead of `BinaryAddition`.
58            let xor =
59                PairwiseXor::new().execute(backend, (result.wires(), mux.wires()), channel)?;
60            result = BinaryBundle::new(xor);
61        }
62        Ok(result)
63    }
64}
65
66impl<F: FancyBinary, const N: usize> CircuitInputMapper<F> for LinearOram<N> {
67    fn map(&self, inputs: Vec<<F as crate::Fancy>::Item>) -> Self::Input {
68        assert_eq!(inputs.len(), (self.size + 1) * N);
69        let (ram_bits, query_bits) = inputs.split_at(self.size * N);
70
71        let ram: Vec<BinaryBundle<F::Item>> = ram_bits
72            .chunks(N)
73            .map(|chunk| BinaryBundle::new(chunk.to_vec()))
74            .collect();
75        let query = BinaryBundle::new(query_bits.to_vec());
76
77        (ram, query)
78    }
79
80    fn ninputs(&self) -> usize {
81        (self.size + 1) * N
82    }
83
84    fn modulus(&self, _: usize) -> u16 {
85        2
86    }
87}
88
#[cfg(test)]
pub mod test {
    use crate::{BinaryBundle, circuits::LinearOram};

    /// A query naming a valid slot must return exactly that slot's contents.
    #[test]
    fn linear_oram() {
        use crate::dummy::{Dummy, DummyVal};
        use rand::Rng;

        const N: usize = 128;
        let mut rng = rand::thread_rng();
        let ram_size = 10;
        let c = LinearOram::<N>::new(ram_size);

        for _ in 0..16 {
            let ram: Vec<u128> = (0..ram_size).map(|_| rng.r#gen::<u128>()).collect();
            let index = rng.r#gen::<usize>() % ram_size;

            let ram_input: Vec<BinaryBundle<DummyVal>> =
                ram.iter().map(|&val| DummyVal::to_binary(val, N)).collect();
            let query_input = DummyVal::to_binary(index as u128, N);
            let output = Dummy::eval(&c, (ram_input, query_input)).unwrap();
            let result = DummyVal::from_binary(&output);
            assert_eq!(result, ram[index]);
        }
    }

    /// A query that matches no slot (out of range, or an empty RAM) must return zero.
    #[test]
    fn linear_oram_miss_returns_zero() {
        use crate::dummy::{Dummy, DummyVal};
        use rand::Rng;

        const N: usize = 128;
        let mut rng = rand::thread_rng();
        let ram_size = 10;
        let c = LinearOram::<N>::new(ram_size);

        for _ in 0..16 {
            let ram: Vec<u128> = (0..ram_size).map(|_| rng.r#gen::<u128>()).collect();
            // Any query >= ram_size names no slot.
            let miss = ram_size as u128 + u128::from(rng.r#gen::<u64>());

            let ram_input: Vec<BinaryBundle<DummyVal>> =
                ram.iter().map(|&val| DummyVal::to_binary(val, N)).collect();
            let query_input = DummyVal::to_binary(miss, N);
            let output = Dummy::eval(&c, (ram_input, query_input)).unwrap();
            assert_eq!(DummyVal::from_binary(&output), 0);
        }

        // An empty RAM has no slots at all, so every query misses.
        let empty = LinearOram::<N>::new(0);
        let query_input = DummyVal::to_binary(0, N);
        let output = Dummy::eval(&empty, (Vec::new(), query_input)).unwrap();
        assert_eq!(DummyVal::from_binary(&output), 0);
    }
}