1use std::collections::BTreeMap;
2
3use bitcoin_hashes::Hash;
4use fedimint_core::encoding::{Decodable, Encodable};
5use rand::distributions::{Distribution, WeightedIndex};
6use rand::seq::IteratorRandom;
7use rand_chacha::ChaCha20Rng;
8use rand_chacha::rand_core::SeedableRng;
9
10fn checksum(data: &[u8]) -> [u8; 4] {
11 bitcoin_hashes::sha256::Hash::hash(data).to_byte_array()[..4]
12 .try_into()
13 .unwrap()
14}
15
/// Errors returned by `Decoder::receive`.
#[derive(Debug)]
pub enum Error {
    /// The fragment's metadata is malformed, or its payload length does not
    /// match the fragment length implied by that metadata.
    InvalidFragment,
    /// The fragment's metadata differs from that of previously accepted
    /// fragments, i.e. it belongs to a different encoding.
    InconsistentFragment,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidFragment => write!(f, "received invalid fragment"),
            Self::InconsistentFragment => write!(f, "fragment is inconsistent with previous ones"),
        }
    }
}

// Lets callers box this error as `dyn std::error::Error` or convert it with `?`
// into error types such as `Box<dyn Error>`.
impl std::error::Error for Error {}
32
/// Fountain-code encoder: splits a message into equally sized simple
/// fragments and emits an open-ended sequence of fragments, first each
/// simple fragment in order, then pseudo-random XOR combinations of them.
#[derive(Debug)]
pub struct Encoder {
    /// The zero-padded message, split into equally sized simple fragments.
    fragments: Vec<Vec<u8>>,
    /// Length of the original message in bytes, before padding.
    message_length: usize,
    /// First four bytes of the message's SHA-256 digest.
    checksum: [u8; 4],
    /// Index of the next fragment `next_fragment` will emit.
    index: u32,
}
41
42impl Encoder {
43 pub fn new(message: &[u8], max_fragment_length: usize) -> Self {
44 assert!(!message.is_empty());
45 assert!(max_fragment_length > 0);
46
47 let fragment_length = fragment_length(message.len(), max_fragment_length);
48
49 let fragments = partition(message.to_vec(), fragment_length);
50
51 Self {
52 fragments,
53 message_length: message.len(),
54 checksum: checksum(message),
55 index: 0,
56 }
57 }
58
59 pub fn next_fragment(&mut self) -> Fragment {
65 let index = self.index;
66
67 self.index += 1;
68
69 let indexes = choose_fragments(self.fragments.len(), self.checksum, index);
70
71 let mut data = vec![0; self.fragments[0].len()];
72
73 for item in indexes {
74 xor(&mut data, &self.fragments[item]);
75 }
76
77 Fragment {
78 meta: EncodingMetadata::new(self.fragments.len(), self.message_length, self.checksum),
79 index,
80 data,
81 }
82 }
83}
84
/// Computes the per-fragment length used to split `data_length` bytes.
///
/// First determines how many fragments of at most `max_fragment_length`
/// bytes are needed, then spreads the data evenly across that many fragments
/// (rounding up). The result never exceeds `max_fragment_length`.
///
/// # Panics
/// Panics on division by zero if `max_fragment_length` is zero, or if
/// `data_length` is zero (no fragments are needed).
pub const fn fragment_length(data_length: usize, max_fragment_length: usize) -> usize {
    let fragment_count = data_length.div_ceil(max_fragment_length);
    data_length.div_ceil(fragment_count)
}
88
/// Splits `data` into chunks of exactly `fragment_length` bytes.
///
/// The tail is zero-padded up to the next multiple of `fragment_length`.
/// Empty `data` yields no chunks.
///
/// # Panics
/// Panics if `fragment_length` is zero.
pub fn partition(mut data: Vec<u8>, fragment_length: usize) -> Vec<Vec<u8>> {
    let padded_length = data.len().div_ceil(fragment_length) * fragment_length;
    data.resize(padded_length, 0);

    data.chunks_exact(fragment_length)
        .map(|chunk| chunk.to_vec())
        .collect()
}
96
/// Returns the indexes of the simple fragments XOR-ed together to form
/// fragment number `index`, for an encoding with `fragment_count` simple
/// fragments and the given message `checksum`.
///
/// Indexes below `fragment_count` map one-to-one onto the simple fragments.
/// Every later index selects a pseudo-random subset. The subset size
/// ("degree") is drawn with weight `1 / degree`, using a ChaCha20 RNG seeded
/// from `(checksum, index).consensus_hash_sha256()`.
///
/// Both `Encoder::next_fragment` and `Fragment::indexes` (used by the
/// decoder) call this function, so its output must stay bit-for-bit
/// reproducible. Do not change the seed derivation, the weights, or the
/// order of RNG draws.
///
/// The weighted distribution is rebuilt over all `fragment_count` degrees on
/// every call, so the cost grows with `fragment_count`.
///
/// # Panics
/// Panics (via `unwrap`) if the weighted distribution cannot be built, e.g.
/// when `fragment_count` is zero and `index` is not below it.
fn choose_fragments(fragment_count: usize, checksum: [u8; 4], index: u32) -> Vec<usize> {
    // Systematic part: fragment `i < fragment_count` is simple fragment `i`.
    if (index as usize) < fragment_count {
        return vec![index as usize];
    }

    let seed = (checksum, index).consensus_hash_sha256();

    let mut rng = ChaCha20Rng::from_seed(seed.to_byte_array());

    // Sample position `x` has weight `1 / (x + 1)`, and `degree = x + 1`
    // lands in `1..=fragment_count`.
    let degree = WeightedIndex::new((0..fragment_count).map(|x| 1.0 / (x + 1) as f64))
        .unwrap()
        .sample(&mut rng)
        + 1;

    // `degree` distinct indexes from `0..fragment_count` (not sorted).
    (0..fragment_count).choose_multiple(&mut rng, degree)
}
115
/// XORs `v2` into `v1` byte by byte, in place.
///
/// # Panics
/// Panics if the two slices differ in length.
fn xor(v1: &mut [u8], v2: &[u8]) {
    assert_eq!(v1.len(), v2.len());

    v1.iter_mut().zip(v2).for_each(|(target, source)| *target ^= source);
}
123
/// Encoding parameters carried by every fragment. They let a decoder validate
/// fragments, recover mixed-fragment indexes and trim the padding from the
/// reassembled message.
///
/// NOTE(review): the field order presumably defines the derived
/// `Encodable`/`Decodable` wire format; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Eq, Encodable, Decodable)]
pub struct EncodingMetadata {
    /// Number of simple fragments the padded message was split into.
    simple_fragments: u32,
    /// Length of the original message in bytes, before padding.
    message_length: u32,
    /// First four bytes of the message's SHA-256 digest; also part of the
    /// seed used to select the simple fragments of mixed fragments.
    checksum: [u8; 4],
}
131
132impl EncodingMetadata {
133 pub fn new(simple_fragments: usize, message_length: usize, checksum: [u8; 4]) -> Self {
134 Self {
135 simple_fragments: simple_fragments as u32,
136 message_length: message_length as u32,
137 checksum,
138 }
139 }
140
141 pub fn fragment_length(&self) -> usize {
142 self.message_length().div_ceil(self.simple_fragments())
143 }
144
145 pub fn message_length(&self) -> usize {
146 self.message_length as usize
147 }
148
149 pub fn checksum(&self) -> [u8; 4] {
150 self.checksum
151 }
152
153 fn simple_fragments(&self) -> usize {
154 self.simple_fragments as usize
155 }
156
157 fn verify(&self) -> bool {
158 self.simple_fragments() > 0 && self.message_length() > 0 && self.fragment_length() > 0
159 }
160}
161
/// One unit of encoder output: either a single simple fragment of the padded
/// message or the XOR of several of them.
///
/// NOTE(review): the field order presumably defines the derived
/// `Encodable`/`Decodable` wire format; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Eq, Encodable, Decodable)]
pub struct Fragment {
    /// Parameters of the encoding this fragment belongs to.
    meta: EncodingMetadata,
    /// Sequence number assigned by the encoder; together with `meta` it
    /// determines which simple fragments are XOR-ed into `data`.
    index: u32,
    /// XOR of the selected simple fragments.
    data: Vec<u8>,
}
169
170impl Fragment {
171 pub fn indexes(&self) -> Vec<usize> {
173 choose_fragments(
174 self.meta.simple_fragments(),
175 self.meta.checksum(),
176 self.index,
177 )
178 }
179}
180
/// Incrementally reassembles a message from `Fragment`s.
///
/// Fragments may arrive in any order. Mixed fragments that cannot be resolved
/// yet are buffered until enough simple fragments are known.
#[derive(Default)]
pub struct Decoder {
    /// Recovered simple fragments, keyed by their index in the message.
    decoded: BTreeMap<usize, Vec<u8>>,
    /// Mixed fragments not resolved yet, keyed by the simple-fragment indexes
    /// still XOR-ed into their data.
    buffer: BTreeMap<Vec<usize>, Vec<u8>>,
    /// Metadata of the first accepted fragment; later fragments must match it.
    meta: Option<EncodingMetadata>,
}
189
190impl Decoder {
191 pub fn message(&self) -> Option<Vec<u8>> {
193 if self.decoded.len() < self.meta.as_ref()?.simple_fragments() {
194 return None;
195 }
196
197 let message = self
198 .decoded
199 .values()
200 .flat_map(|data| data.clone())
201 .take(self.meta.as_ref()?.message_length())
202 .collect();
203
204 Some(message)
205 }
206
207 pub fn receive(&mut self, fragment: Fragment) -> Result<Option<Vec<u8>>, Error> {
209 if let Some(message) = self.message() {
210 return Ok(Some(message));
211 }
212
213 if !fragment.meta.verify() {
214 return Err(Error::InvalidFragment);
215 }
216
217 if fragment.data.len() != fragment.meta.fragment_length() {
218 return Err(Error::InvalidFragment);
219 }
220
221 match self.meta.as_ref() {
222 None => {
223 self.meta = Some(fragment.meta.clone());
224 }
225 Some(meta) => {
226 if meta != &fragment.meta {
227 return Err(Error::InconsistentFragment);
228 }
229 }
230 }
231
232 if let [index] = fragment.indexes().as_slice() {
233 self.process_simple(*index, fragment.data.clone());
234 } else {
235 self.process_complex(fragment.indexes(), fragment.data.clone());
236 }
237
238 Ok(self.message())
239 }
240
241 fn process_simple(&mut self, index: usize, data: Vec<u8>) {
242 self.decoded.insert(index, data.clone());
243
244 let mut queue = self.decoded.clone().into_iter().collect::<Vec<_>>();
245
246 while let Some((index, simple)) = queue.pop() {
247 for (mut indexes, mut data) in self
248 .buffer
249 .clone()
250 .into_iter()
251 .filter(|entry| entry.0.contains(&index))
252 {
253 self.buffer.remove(&indexes).unwrap();
254
255 indexes.retain(|&i| i != index);
256
257 xor(&mut data, &simple);
258
259 if let [index] = indexes.as_slice() {
260 self.decoded.insert(*index, data.clone());
261 queue.push((*index, data));
262 } else {
263 self.buffer.insert(indexes, data);
264 }
265 }
266 }
267 }
268
269 fn process_complex(&mut self, mut indexes: Vec<usize>, mut data: Vec<u8>) {
270 let to_remove: Vec<usize> = indexes
271 .clone()
272 .into_iter()
273 .filter(|i| self.decoded.keys().any(|k| k == i))
274 .collect();
275
276 if indexes.len() == to_remove.len() {
277 return;
278 }
279
280 for remove in &to_remove {
281 xor(&mut data, self.decoded.get(remove).unwrap());
282 }
283
284 indexes.retain(|&i| !to_remove.contains(&i));
285
286 if let [index] = indexes.as_slice() {
287 self.decoded.insert(*index, data.clone());
288 } else {
289 self.buffer.insert(indexes, data);
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// `fragment_length` spreads the data evenly over the minimum number of
    /// fragments of at most `max_fragment_length` bytes.
    #[test]
    fn test_fragment_length() {
        assert_eq!(fragment_length(12345, 1955), 1764);
        assert_eq!(fragment_length(12345, 30000), 12345);

        assert_eq!(fragment_length(10, 4), 4);
        assert_eq!(fragment_length(10, 5), 5);
        assert_eq!(fragment_length(10, 6), 5);
        assert_eq!(fragment_length(10, 10), 10);
    }

    /// `Encoder::new` asserts that `max_fragment_length` is non-zero.
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_fountain_encoder_zero_max_length() {
        Encoder::new(b"foo", 0);
    }

    /// `Encoder::new` asserts that the message is non-empty.
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_empty_encoder() {
        Encoder::new(&[], 1);
    }

    /// A 3-byte message with a 2-byte maximum splits into two simple
    /// fragments. Fragments from an encoder with a different message (and
    /// therefore a different checksum) are rejected as inconsistent.
    #[test]
    fn test_decoder_fragment_validation() {
        let mut encoder1 = Encoder::new(b"foo", 2);
        let mut encoder2 = Encoder::new(b"bar", 2);
        let mut decoder = Decoder::default();

        // Fragment 0 of 2: not enough to reassemble yet.
        assert_eq!(decoder.receive(encoder1.next_fragment()).unwrap(), None);

        // Same shape, different checksum: metadata mismatch.
        assert!(matches!(
            decoder.receive(encoder2.next_fragment()),
            Err(Error::InconsistentFragment)
        ));

        // Fragment 1 of 2 completes the message.
        assert_eq!(
            decoder.receive(encoder1.next_fragment()).unwrap(),
            Some(b"foo".to_vec())
        );
    }

    /// Each malformed variant of an otherwise valid fragment is rejected with
    /// `InvalidFragment`.
    #[test]
    fn test_empty_decoder_empty_fragment() {
        let mut decoder = Decoder::default();
        let mut fragment = Fragment {
            meta: EncodingMetadata::new(8, 100, [0x12, 0x34, 0x56, 0x78]),
            index: 12,
            data: vec![1, 5, 3, 3, 5],
        };

        // Zero simple fragments fails metadata verification.
        fragment.meta.simple_fragments = 0;
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
        fragment.meta.simple_fragments = 8;

        // Zero message length fails metadata verification.
        fragment.meta.message_length = 0;
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
        fragment.meta.message_length = 100;

        // Metadata is valid again (fragment length 13), so the empty payload
        // is rejected for its length mismatch.
        fragment.data = vec![];
        assert!(matches!(
            decoder.receive(fragment.clone()),
            Err(Error::InvalidFragment)
        ));
    }
}
375}