fedimint_core/encoding/
collections.rs

1use std::any::TypeId;
2use std::collections::{BTreeMap, BTreeSet, VecDeque};
3use std::fmt::Debug;
4
5use crate::module::registry::ModuleRegistry;
6use crate::{Decodable, DecodeError, Encodable, ModuleDecoderRegistry};
7
8impl<T> Encodable for &[T]
9where
10    T: Encodable + 'static,
11{
12    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
13        if TypeId::of::<T>() == TypeId::of::<u8>() {
14            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
15            let bytes = unsafe { std::mem::transmute::<&[T], &[u8]>(self) };
16
17            (bytes.len() as u64).consensus_encode(writer)?;
18            writer.write_all(bytes)?;
19            return Ok(());
20        }
21
22        (self.len() as u64).consensus_encode(writer)?;
23
24        for item in *self {
25            item.consensus_encode(writer)?;
26        }
27        Ok(())
28    }
29}
30
31impl<T> Encodable for Vec<T>
32where
33    T: Encodable + 'static,
34{
35    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
36        (self as &[T]).consensus_encode(writer)
37    }
38}
39
40impl<T> Decodable for Vec<T>
41where
42    T: Decodable + 'static,
43{
44    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
45        d: &mut D,
46        modules: &ModuleDecoderRegistry,
47    ) -> Result<Self, DecodeError> {
48        const CHUNK_SIZE: usize = 64 * 1024;
49
50        if TypeId::of::<T>() == TypeId::of::<u8>() {
51            let len =
52                u64::consensus_decode_partial_from_finite_reader(d, &ModuleRegistry::default())?;
53
54            let mut len: usize =
55                usize::try_from(len).map_err(|_| DecodeError::from_str("size exceeds memory"))?;
56
57            let mut bytes = vec![];
58
59            // Adapted from <https://github.com/rust-bitcoin/rust-bitcoin/blob/e2b9555070d9357fb552e56085fb6fb3f0274560/bitcoin/src/consensus/encode.rs#L667-L674>
60            while len > 0 {
61                let chunk_start = bytes.len();
62                let chunk_size = core::cmp::min(len, CHUNK_SIZE);
63                let chunk_end = chunk_start + chunk_size;
64                bytes.resize(chunk_end, 0u8);
65                d.read_exact(&mut bytes[chunk_start..chunk_end])
66                    .map_err(DecodeError::from_err)?;
67                len -= chunk_size;
68            }
69
70            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
71            return Ok(unsafe { std::mem::transmute::<Vec<u8>, Self>(bytes) });
72        }
73        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
74
75        // `collect` under the hood uses `FromIter::from_iter`, which can potentially be
76        // backed by code like:
77        // <https://github.com/rust-lang/rust/blob/fe03b46ee4688a99d7155b4f9dcd875b6903952d/library/alloc/src/vec/spec_from_iter_nested.rs#L31>
78        // This can take `size_hint` from input iterator and pre-allocate memory
79        // upfront with `Vec::with_capacity`. Because of that untrusted `len`
80        // should not be used directly.
81        let cap_len = std::cmp::min(8_000 / std::mem::size_of::<T>() as u64, len);
82
83        // Up to a cap, use the (potentially specialized for better perf in stdlib)
84        // `from_iter`.
85        let mut v: Self = (0..cap_len)
86            .map(|_| T::consensus_decode_partial_from_finite_reader(d, modules))
87            .collect::<Result<Self, DecodeError>>()?;
88
89        // Add any excess manually avoiding any surprises.
90        while (v.len() as u64) < len {
91            v.push(T::consensus_decode_partial_from_finite_reader(d, modules)?);
92        }
93
94        assert_eq!(v.len() as u64, len);
95
96        Ok(v)
97    }
98}
99
100impl<T> Encodable for VecDeque<T>
101where
102    T: Encodable + 'static,
103{
104    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
105        (self.len() as u64).consensus_encode(writer)?;
106        for i in self {
107            i.consensus_encode(writer)?;
108        }
109        Ok(())
110    }
111}
112
113impl<T> Decodable for VecDeque<T>
114where
115    T: Decodable + 'static,
116{
117    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
118        d: &mut D,
119        modules: &ModuleDecoderRegistry,
120    ) -> Result<Self, DecodeError> {
121        Ok(Self::from(
122            Vec::<T>::consensus_decode_partial_from_finite_reader(d, modules)?,
123        ))
124    }
125}
126
127impl<T, const SIZE: usize> Encodable for [T; SIZE]
128where
129    T: Encodable + 'static,
130{
131    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
132        if TypeId::of::<T>() == TypeId::of::<u8>() {
133            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
134            let bytes = unsafe { std::mem::transmute::<&[T; SIZE], &[u8; SIZE]>(self) };
135            writer.write_all(bytes)?;
136            return Ok(());
137        }
138
139        for item in self {
140            item.consensus_encode(writer)?;
141        }
142        Ok(())
143    }
144}
145
146impl<T, const SIZE: usize> Decodable for [T; SIZE]
147where
148    T: Decodable + Debug + Default + Copy + 'static,
149{
150    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
151        d: &mut D,
152        modules: &ModuleDecoderRegistry,
153    ) -> Result<Self, DecodeError> {
154        // From <https://github.com/rust-lang/rust/issues/61956>
155        unsafe fn horribe_array_transmute_workaround<const N: usize, A, B>(
156            mut arr: [A; N],
157        ) -> [B; N] {
158            let ptr = std::ptr::from_mut(&mut arr).cast::<[B; N]>();
159            let res = unsafe { ptr.read() };
160            core::mem::forget(arr);
161            res
162        }
163
164        if TypeId::of::<T>() == TypeId::of::<u8>() {
165            let mut bytes = [0u8; SIZE];
166            d.read_exact(bytes.as_mut_slice())
167                .map_err(DecodeError::from_err)?;
168
169            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
170            return Ok(unsafe { horribe_array_transmute_workaround(bytes) });
171        }
172
173        // todo: impl without copy
174        let mut data = [T::default(); SIZE];
175        for item in &mut data {
176            *item = T::consensus_decode_partial_from_finite_reader(d, modules)?;
177        }
178        Ok(data)
179    }
180}
181
182impl<K, V> Encodable for BTreeMap<K, V>
183where
184    K: Encodable,
185    V: Encodable,
186{
187    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
188        (self.len() as u64).consensus_encode(writer)?;
189        for (k, v) in self {
190            k.consensus_encode(writer)?;
191            v.consensus_encode(writer)?;
192        }
193        Ok(())
194    }
195}
196
197impl<K, V> Decodable for BTreeMap<K, V>
198where
199    K: Decodable + Ord,
200    V: Decodable,
201{
202    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
203        d: &mut D,
204        modules: &ModuleDecoderRegistry,
205    ) -> Result<Self, DecodeError> {
206        let mut res = Self::new();
207        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
208        for _ in 0..len {
209            let k = K::consensus_decode_partial_from_finite_reader(d, modules)?;
210            if res
211                .last_key_value()
212                .is_some_and(|(prev_key, _v)| k <= *prev_key)
213            {
214                return Err(DecodeError::from_str("Non-canonical encoding"));
215            }
216            let v = V::consensus_decode_partial_from_finite_reader(d, modules)?;
217            if res.insert(k, v).is_some() {
218                return Err(DecodeError(anyhow::format_err!("Duplicate key")));
219            }
220        }
221        Ok(res)
222    }
223}
224
225impl<K> Encodable for BTreeSet<K>
226where
227    K: Encodable,
228{
229    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
230        (self.len() as u64).consensus_encode(writer)?;
231        for k in self {
232            k.consensus_encode(writer)?;
233        }
234        Ok(())
235    }
236}
237
238impl<K> Decodable for BTreeSet<K>
239where
240    K: Decodable + Ord,
241{
242    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
243        d: &mut D,
244        modules: &ModuleDecoderRegistry,
245    ) -> Result<Self, DecodeError> {
246        let mut res = Self::new();
247        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
248        for _ in 0..len {
249            let k = K::consensus_decode_partial_from_finite_reader(d, modules)?;
250            if res.last().is_some_and(|prev_key| k <= *prev_key) {
251                return Err(DecodeError::from_str("Non-canonical encoding"));
252            }
253            if !res.insert(k) {
254                return Err(DecodeError(anyhow::format_err!("Duplicate key")));
255            }
256        }
257        Ok(res)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::tests::test_roundtrip_expected;

    #[test_log::test]
    fn test_lists() {
        // The length of the list is encoded before the elements. It is encoded as a
        // variable length integer, but for lists with a length less than 253, it's
        // encoded as a single byte.
        test_roundtrip_expected(&vec![1u8, 2, 3], &[3u8, 1, 2, 3]);
        test_roundtrip_expected(&vec![1u16, 2, 3], &[3u8, 1, 2, 3]);
        test_roundtrip_expected(&vec![1u32, 2, 3], &[3u8, 1, 2, 3]);
        test_roundtrip_expected(&vec![1u64, 2, 3], &[3u8, 1, 2, 3]);

        // Empty list should be encoded as a single byte 0.
        test_roundtrip_expected::<Vec<u8>>(&vec![], &[0u8]);
        test_roundtrip_expected::<Vec<u16>>(&vec![], &[0u8]);
        test_roundtrip_expected::<Vec<u32>>(&vec![], &[0u8]);
        test_roundtrip_expected::<Vec<u64>>(&vec![], &[0u8]);

        // A length prefix greater than the number of elements should return an error.
        let buf = [4u8, 1, 2, 3];
        assert!(Vec::<u8>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
        assert!(Vec::<u16>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
        assert!(VecDeque::<u8>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
        assert!(VecDeque::<u16>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());

        // A length prefix less than the number of elements should skip elements beyond
        // the encoded length.
        let buf = [2u8, 1, 2, 3];
        assert_eq!(
            Vec::<u8>::consensus_decode_partial(&mut &buf[..], &ModuleRegistry::default()).unwrap(),
            vec![1u8, 2]
        );
        assert_eq!(
            Vec::<u16>::consensus_decode_partial(&mut &buf[..], &ModuleRegistry::default())
                .unwrap(),
            vec![1u16, 2]
        );
        assert_eq!(
            VecDeque::<u8>::consensus_decode_partial(&mut &buf[..], &ModuleRegistry::default())
                .unwrap(),
            vec![1u8, 2]
        );
        assert_eq!(
            VecDeque::<u16>::consensus_decode_partial(&mut &buf[..], &ModuleRegistry::default())
                .unwrap(),
            vec![1u16, 2]
        );
    }

    #[test_log::test]
    fn test_arrays() {
        // Fixed-size arrays are encoded as their elements only, without a length
        // prefix.
        test_roundtrip_expected(&[1u8, 2, 3], &[1u8, 2, 3]);
        test_roundtrip_expected(&[1u16, 2, 3], &[1u8, 2, 3]);

        // Fewer elements than the array size should return an error.
        let buf = [1u8, 2, 3];
        assert!(<[u8; 4]>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
        assert!(<[u16; 4]>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
    }

    #[test_log::test]
    fn test_btreemap() {
        test_roundtrip_expected(
            &BTreeMap::from([("a".to_string(), 1u32), ("b".to_string(), 2)]),
            &[2, 1, 97, 1, 1, 98, 2],
        );
    }

    #[test_log::test]
    fn test_btreemap_rejects_non_canonical() {
        // Keys in descending order (2, then 1) should return an error.
        let buf = [2u8, 2, 0, 1, 0];
        assert!(
            BTreeMap::<u16, u16>::consensus_decode_whole(&buf, &ModuleRegistry::default())
                .is_err()
        );

        // A repeated key should return an error.
        let buf = [2u8, 1, 0, 1, 0];
        assert!(
            BTreeMap::<u16, u16>::consensus_decode_whole(&buf, &ModuleRegistry::default())
                .is_err()
        );
    }

    #[test_log::test]
    fn test_btreeset() {
        test_roundtrip_expected(
            &BTreeSet::from(["a".to_string(), "b".to_string()]),
            &[2, 1, 97, 1, 98],
        );
    }

    #[test_log::test]
    fn test_btreeset_rejects_non_canonical() {
        // Elements in descending order (2, then 1) should return an error.
        let buf = [2u8, 2, 1];
        assert!(BTreeSet::<u16>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());

        // A repeated element should return an error.
        let buf = [2u8, 1, 1];
        assert!(BTreeSet::<u16>::consensus_decode_whole(&buf, &ModuleRegistry::default()).is_err());
    }
}