fedimint_fountain/
fountain.rs

1use std::collections::BTreeMap;
2
3use bitcoin_hashes::Hash;
4use fedimint_core::encoding::{Decodable, Encodable};
5use rand::distributions::{Distribution, WeightedIndex};
6use rand::seq::IteratorRandom;
7use rand_chacha::ChaCha20Rng;
8use rand_chacha::rand_core::SeedableRng;
9
10fn checksum(data: &[u8]) -> [u8; 4] {
11    bitcoin_hashes::sha256::Hash::hash(data).to_byte_array()[..4]
12        .try_into()
13        .unwrap()
14}
15
/// Errors returned by [`Decoder::receive`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Received fragment is invalid: its metadata fails validation or its
    /// payload length does not match that metadata.
    InvalidFragment,
    /// Received fragment's metadata differs from that of the fragments the
    /// decoder has already accepted.
    InconsistentFragment,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidFragment => write!(f, "received invalid fragment"),
            Self::InconsistentFragment => write!(f, "fragment is inconsistent with previous ones"),
        }
    }
}

// Lets callers propagate decoder errors with `?` into `Box<dyn Error>`,
// `anyhow::Error`, and similar error-erasing types.
impl std::error::Error for Error {}
32
/// Fountain encoder: splits a message into equally sized, zero-padded simple
/// fragments and emits a stream of [`Fragment`]s derived from them.
#[derive(Debug)]
pub struct Encoder {
    /// The zero-padded message, split into fragments of equal length.
    fragments: Vec<Vec<u8>>,
    /// Length in bytes of the original, unpadded message.
    message_length: usize,
    /// First four bytes of the SHA-256 digest of the message.
    checksum: [u8; 4],
    /// Index that the next emitted fragment will carry.
    index: u32,
}
41
42impl Encoder {
43    pub fn new(message: &[u8], max_fragment_length: usize) -> Self {
44        assert!(!message.is_empty());
45        assert!(max_fragment_length > 0);
46
47        let fragment_length = fragment_length(message.len(), max_fragment_length);
48
49        let fragments = partition(message.to_vec(), fragment_length);
50
51        Self {
52            fragments,
53            message_length: message.len(),
54            checksum: checksum(message),
55            index: 0,
56        }
57    }
58
59    /// Returns the next fragment to be emitted by the fountain encoder.
60    /// After all fragments of the original message have been emitted once,
61    /// the fountain encoder will emit the result of xoring together the
62    /// fragments selected by the Xoshiro RNG (which could be a single
63    /// fragment).
64    pub fn next_fragment(&mut self) -> Fragment {
65        let index = self.index;
66
67        self.index += 1;
68
69        let indexes = choose_fragments(self.fragments.len(), self.checksum, index);
70
71        let mut data = vec![0; self.fragments[0].len()];
72
73        for item in indexes {
74            xor(&mut data, &self.fragments[item]);
75        }
76
77        Fragment {
78            meta: EncodingMetadata::new(self.fragments.len(), self.message_length, self.checksum),
79            index,
80            data,
81        }
82    }
83}
84
/// Picks the fragment length for splitting `data_length` bytes into the fewest
/// fragments of at most `max_fragment_length` bytes each. The bytes are spread
/// as evenly as possible, so only the last fragment needs padding.
///
/// # Panics
///
/// Panics (division by zero) if either argument is zero.
pub const fn fragment_length(data_length: usize, max_fragment_length: usize) -> usize {
    let fragment_count = data_length.div_ceil(max_fragment_length);
    data_length.div_ceil(fragment_count)
}
88
/// Splits `data` into consecutive chunks of exactly `fragment_length` bytes,
/// zero-padding the final chunk when needed.
///
/// # Panics
///
/// Panics (division by zero) if `fragment_length` is zero.
pub fn partition(mut data: Vec<u8>, fragment_length: usize) -> Vec<Vec<u8>> {
    let padded_length = data.len().div_ceil(fragment_length) * fragment_length;
    data.resize(padded_length, 0);

    data.chunks_exact(fragment_length)
        .map(|chunk| chunk.to_vec())
        .collect()
}
96
97fn choose_fragments(fragment_count: usize, checksum: [u8; 4], index: u32) -> Vec<usize> {
98    if (index as usize) < fragment_count {
99        return vec![index as usize];
100    }
101
102    let seed = (checksum, index).consensus_hash_sha256();
103
104    let mut rng = ChaCha20Rng::from_seed(seed.to_byte_array());
105
106    // Sample degree from Ideal Soliton Distribution: P(degree = k) ∝ 1/k
107    let degree = WeightedIndex::new((0..fragment_count).map(|x| 1.0 / (x + 1) as f64))
108        .unwrap()
109        .sample(&mut rng)
110        + 1;
111
112    // Choose degree random fragments
113    (0..fragment_count).choose_multiple(&mut rng, degree)
114}
115
/// XORs `v2` into `v1` in place, byte by byte.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
fn xor(v1: &mut [u8], v2: &[u8]) {
    assert_eq!(v1.len(), v2.len());

    v1.iter_mut()
        .zip(v2)
        .for_each(|(dst, src)| *dst ^= *src);
}
123
124/// Encoding metadata for a fragment.
125#[derive(Clone, Debug, PartialEq, Eq, Encodable, Decodable)]
126pub struct EncodingMetadata {
127    simple_fragments: u32,
128    message_length: u32,
129    checksum: [u8; 4],
130}
131
132impl EncodingMetadata {
133    pub fn new(simple_fragments: usize, message_length: usize, checksum: [u8; 4]) -> Self {
134        Self {
135            simple_fragments: simple_fragments as u32,
136            message_length: message_length as u32,
137            checksum,
138        }
139    }
140
141    pub fn fragment_length(&self) -> usize {
142        self.message_length().div_ceil(self.simple_fragments())
143    }
144
145    pub fn message_length(&self) -> usize {
146        self.message_length as usize
147    }
148
149    pub fn checksum(&self) -> [u8; 4] {
150        self.checksum
151    }
152
153    fn simple_fragments(&self) -> usize {
154        self.simple_fragments as usize
155    }
156
157    fn verify(&self) -> bool {
158        self.simple_fragments() > 0 && self.message_length() > 0 && self.fragment_length() > 0
159    }
160}
161
162/// A fragment emitted by a fountain [`Encoder`].
163#[derive(Clone, Debug, PartialEq, Eq, Encodable, Decodable)]
164pub struct Fragment {
165    meta: EncodingMetadata,
166    index: u32,
167    data: Vec<u8>,
168}
169
170impl Fragment {
171    /// Returns the indexes of the message segments that were combined.
172    pub fn indexes(&self) -> Vec<usize> {
173        choose_fragments(
174            self.meta.simple_fragments(),
175            self.meta.checksum(),
176            self.index,
177        )
178    }
179}
180
181/// A decoder capable of receiving and recombining fountain-encoded
182/// transmissions.
183#[derive(Default)]
184pub struct Decoder {
185    decoded: BTreeMap<usize, Vec<u8>>,
186    buffer: BTreeMap<Vec<usize>, Vec<u8>>,
187    meta: Option<EncodingMetadata>,
188}
189
190impl Decoder {
191    /// If the message is available, returns it, `None` otherwise.
192    pub fn message(&self) -> Option<Vec<u8>> {
193        if self.decoded.len() < self.meta.as_ref()?.simple_fragments() {
194            return None;
195        }
196
197        let message = self
198            .decoded
199            .values()
200            .flat_map(|data| data.clone())
201            .take(self.meta.as_ref()?.message_length())
202            .collect();
203
204        Some(message)
205    }
206
207    /// Receives a fountain-encoded fragment into the decoder.
208    pub fn receive(&mut self, fragment: Fragment) -> Result<Option<Vec<u8>>, Error> {
209        if let Some(message) = self.message() {
210            return Ok(Some(message));
211        }
212
213        if !fragment.meta.verify() {
214            return Err(Error::InvalidFragment);
215        }
216
217        if fragment.data.len() != fragment.meta.fragment_length() {
218            return Err(Error::InvalidFragment);
219        }
220
221        match self.meta.as_ref() {
222            None => {
223                self.meta = Some(fragment.meta.clone());
224            }
225            Some(meta) => {
226                if meta != &fragment.meta {
227                    return Err(Error::InconsistentFragment);
228                }
229            }
230        }
231
232        if let [index] = fragment.indexes().as_slice() {
233            self.process_simple(*index, fragment.data.clone());
234        } else {
235            self.process_complex(fragment.indexes(), fragment.data.clone());
236        }
237
238        Ok(self.message())
239    }
240
241    fn process_simple(&mut self, index: usize, data: Vec<u8>) {
242        self.decoded.insert(index, data.clone());
243
244        let mut queue = self.decoded.clone().into_iter().collect::<Vec<_>>();
245
246        while let Some((index, simple)) = queue.pop() {
247            for (mut indexes, mut data) in self
248                .buffer
249                .clone()
250                .into_iter()
251                .filter(|entry| entry.0.contains(&index))
252            {
253                self.buffer.remove(&indexes).unwrap();
254
255                indexes.retain(|&i| i != index);
256
257                xor(&mut data, &simple);
258
259                if let [index] = indexes.as_slice() {
260                    self.decoded.insert(*index, data.clone());
261                    queue.push((*index, data));
262                } else {
263                    self.buffer.insert(indexes, data);
264                }
265            }
266        }
267    }
268
269    fn process_complex(&mut self, mut indexes: Vec<usize>, mut data: Vec<u8>) {
270        let to_remove: Vec<usize> = indexes
271            .clone()
272            .into_iter()
273            .filter(|i| self.decoded.keys().any(|k| k == i))
274            .collect();
275
276        if indexes.len() == to_remove.len() {
277            return;
278        }
279
280        for remove in &to_remove {
281            xor(&mut data, self.decoded.get(remove).unwrap());
282        }
283
284        indexes.retain(|&i| !to_remove.contains(&i));
285
286        if let [index] = indexes.as_slice() {
287            self.decoded.insert(*index, data.clone());
288        } else {
289            self.buffer.insert(indexes, data);
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fragment_length() {
        assert_eq!(fragment_length(12345, 1955), 1764);
        assert_eq!(fragment_length(12345, 30000), 12345);

        assert_eq!(fragment_length(10, 4), 4);
        assert_eq!(fragment_length(10, 5), 5);
        assert_eq!(fragment_length(10, 6), 5);
        assert_eq!(fragment_length(10, 10), 10);
    }

    #[test]
    fn test_partition_pads_last_fragment() {
        assert_eq!(
            partition(vec![1, 2, 3, 4, 5], 2),
            vec![vec![1, 2], vec![3, 4], vec![5, 0]]
        );
        assert_eq!(partition(vec![1, 2, 3, 4], 2), vec![vec![1, 2], vec![3, 4]]);
    }

    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_fountain_encoder_zero_max_length() {
        Encoder::new(b"foo", 0);
    }

    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_empty_encoder() {
        Encoder::new(&[], 1);
    }

    #[test]
    fn test_decoder_fragment_validation() {
        let mut encoder1 = Encoder::new(b"foo", 2);
        let mut encoder2 = Encoder::new(b"bar", 2);
        let mut decoder = Decoder::default();

        // Receive first fragment from encoder1 - not complete yet
        assert_eq!(decoder.receive(encoder1.next_fragment()).unwrap(), None);

        // Try to receive fragment from encoder2 with different metadata - should reject
        assert!(matches!(
            decoder.receive(encoder2.next_fragment()),
            Err(Error::InconsistentFragment)
        ));

        // Receiving another fragment from encoder1 should work and complete
        assert_eq!(
            decoder.receive(encoder1.next_fragment()).unwrap(),
            Some(b"foo".to_vec())
        );
    }

    #[test]
    fn test_empty_decoder_empty_fragment() {
        let mut decoder = Decoder::default();
        let mut fragment = Fragment {
            meta: EncodingMetadata::new(8, 100, [0x12, 0x34, 0x56, 0x78]),
            index: 12,
            data: vec![1, 5, 3, 3, 5],
        };

        // Check simple_fragments.
        fragment.meta.simple_fragments = 0;
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
        fragment.meta.simple_fragments = 8;

        // Check message_length.
        fragment.meta.message_length = 0;
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
        fragment.meta.message_length = 100;

        // Check data.
        fragment.data = vec![];
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
    }

    /// Builds a deterministic 1000-byte test message (10 fragments of 100 bytes).
    fn sample_message() -> Vec<u8> {
        (0..=u8::MAX).cycle().take(1000).collect()
    }

    #[test]
    fn test_roundtrip_from_simple_fragments() {
        let message = sample_message();
        let mut encoder = Encoder::new(&message, 100);
        let mut decoder = Decoder::default();

        // The first ten fragments are the ten simple fragments, in order.
        for i in 0..10 {
            let result = decoder.receive(encoder.next_fragment()).unwrap();
            if i < 9 {
                assert_eq!(result, None);
            } else {
                assert_eq!(result, Some(message.clone()));
            }
        }

        // Once complete, the decoder keeps returning the message.
        assert_eq!(
            decoder.receive(encoder.next_fragment()).unwrap(),
            Some(message)
        );
    }

    #[test]
    fn test_roundtrip_from_combined_fragments_only() {
        let message = sample_message();
        let mut encoder = Encoder::new(&message, 100);
        let mut decoder = Decoder::default();

        // Drop the ten in-order simple fragments, so recovery has to go through
        // the pseudo-randomly combined ones.
        for _ in 0..10 {
            let _ = encoder.next_fragment();
        }

        for _ in 0..1000 {
            if let Some(decoded) = decoder.receive(encoder.next_fragment()).unwrap() {
                assert_eq!(decoded, message);
                return;
            }
        }

        panic!("decoder failed to recover the message from 1000 fragments");
    }
}