fedimint_core/encoding/
mod.rs

//! Binary encoding interface suitable for
//! consensus critical encoding.
//!
//! Over time all structs that need to be encoded to binary will be migrated
//! to this interface.
//!
//! This code is based on corresponding `rust-bitcoin` types.
//!
//! See [`Encodable`] and [`Decodable`] for the two main traits.
10
11pub mod as_base64;
12pub mod as_hex;
13mod bls12_381;
14pub mod btc;
15mod collections;
16mod iroh;
17mod secp256k1;
18mod threshold_crypto;
19
20use std::borrow::Cow;
21use std::cmp;
22use std::fmt::{Debug, Formatter};
23use std::io::{self, Error, Read, Write};
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25
26use anyhow::Context;
27use bitcoin::hashes::sha256;
28pub use fedimint_derive::{Decodable, Encodable};
29use hex::{FromHex, ToHex};
30use lightning::util::ser::BigSize;
31use serde::{Deserialize, Serialize};
32use thiserror::Error;
33
34use crate::core::ModuleInstanceId;
35use crate::module::registry::ModuleDecoderRegistry;
36use crate::util::SafeUrl;
37
/// A [`Write`] adapter that forwards everything to an inner writer while
/// keeping a running total of how many bytes the inner writer accepted.
///
/// Copy&pasted from <https://github.com/SOF3/count-write> which
/// uses Apache license (and it's a trivial amount of code, repeating
/// on stack overflow).
pub struct CountWrite<W> {
    inner: W,
    count: u64,
}

impl<W> CountWrite<W> {
    /// Total number of bytes the inner writer has reported as written.
    pub fn count(&self) -> u64 {
        self.count
    }
}

impl<W> From<W> for CountWrite<W> {
    /// Wraps `inner`, starting the byte counter at zero.
    fn from(inner: W) -> Self {
        Self { count: 0, inner }
    }
}

impl<W: Write> io::Write for CountWrite<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Only count what the inner writer actually accepted: a short write
        // adds just the accepted amount and an error adds nothing.
        self.inner.write(buf).map(|accepted| {
            self.count += accepted as u64;
            accepted
        })
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
72
73/// Object-safe trait for things that can encode themselves
74///
75/// Like `rust-bitcoin`'s `consensus_encode`, but without generics,
76/// so can be used in `dyn` objects.
77pub trait DynEncodable {
78    fn consensus_encode_dyn(&self, writer: &mut dyn std::io::Write) -> Result<(), std::io::Error>;
79}
80
81impl Encodable for dyn DynEncodable {
82    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
83        self.consensus_encode_dyn(writer)
84    }
85}
86
87impl<T> DynEncodable for T
88where
89    T: Encodable,
90{
91    fn consensus_encode_dyn(
92        &self,
93        mut writer: &mut dyn std::io::Write,
94    ) -> Result<(), std::io::Error> {
95        <Self as Encodable>::consensus_encode(self, &mut writer)
96    }
97}
98
99impl Encodable for Box<dyn DynEncodable> {
100    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
101        (**self).consensus_encode_dyn(writer)
102    }
103}
104
105impl<T> Encodable for &T
106where
107    T: Encodable,
108{
109    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
110        (**self).consensus_encode(writer)
111    }
112}
113
114/// Data which can be encoded in a consensus-consistent way
115pub trait Encodable {
116    /// Encode an object with a well-defined format.
117    /// Returns the number of bytes written on success.
118    ///
119    /// The only errors returned are errors propagated from the writer.
120    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error>;
121
122    /// [`Self::consensus_encode`] to newly allocated `Vec<u8>`
123    fn consensus_encode_to_vec(&self) -> Vec<u8> {
124        let mut bytes = vec![];
125        self.consensus_encode(&mut bytes)
126            .expect("encoding to bytes can't fail for io reasons");
127        bytes
128    }
129
130    /// Encode and convert to hex string representation
131    fn consensus_encode_to_hex(&self) -> String {
132        // TODO: This double allocation offends real Rustaceans. We should
133        // be able to go straight to String, but this use case seems under-served
134        // by hex encoding crates.
135        self.consensus_encode_to_vec().encode_hex()
136    }
137
138    /// Encode without storing the encoding, return the size
139    fn consensus_encode_to_len(&self) -> u64 {
140        let mut writer = CountWrite::from(io::sink());
141        self.consensus_encode(&mut writer)
142            .expect("encoding to bytes can't fail for io reasons");
143
144        writer.count()
145    }
146
147    /// Generate a SHA256 hash of the consensus encoding using the default hash
148    /// engine for `H`.
149    ///
150    /// Can be used to validate all federation members agree on state without
151    /// revealing the object
152    fn consensus_hash<H>(&self) -> H
153    where
154        H: bitcoin::hashes::Hash,
155        H::Engine: std::io::Write,
156    {
157        let mut engine = H::engine();
158        self.consensus_encode(&mut engine)
159            .expect("writing to HashEngine cannot fail");
160        H::from_engine(engine)
161    }
162
163    /// [`Self::consensus_hash`] for [`bitcoin::hashes::sha256::Hash`]
164    fn consensus_hash_sha256(&self) -> sha256::Hash {
165        self.consensus_hash()
166    }
167}
168
/// Upper bound, in bytes, on the input a single value may consume when it is
/// decoded through the default [`Decodable::consensus_decode_partial`]
/// (16 MB, decimal).
pub const MAX_DECODE_SIZE: usize = 16 * 1_000 * 1_000;
172
173/// Data which can be encoded in a consensus-consistent way
174pub trait Decodable: Sized {
175    /// Decode `Self` from a size-limited reader.
176    ///
177    /// Like `consensus_decode_partial` but relies on the reader being limited
178    /// in the amount of data it returns, e.g. by being wrapped in
179    /// [`std::io::Take`].
180    ///
181    /// Failing to abide to this requirement might lead to memory exhaustion
182    /// caused by malicious inputs.
183    ///
184    /// Users should default to `consensus_decode_partial`, but when data to be
185    /// decoded is already in a byte vector of a limited size, calling this
186    /// function directly might be marginally faster (due to avoiding extra
187    /// checks).
188    ///
189    /// ### Rules for trait implementations
190    ///
191    /// * Simple types that that have a fixed size (own and member fields),
192    ///   don't have to overwrite this method, or be concern with it, should
193    ///   only impl `consensus_decode_partial`.
194    /// * Types that deserialize based on decoded untrusted length should
195    ///   implement `consensus_decode_partial_from_finite_reader` only:
196    ///   * Default implementation of `consensus_decode_partial` will forward to
197    ///     `consensus_decode_partial_from_finite_reader` with the reader
198    ///     wrapped by `Take`, protecting from readers that keep returning data.
199    ///   * Implementation must make sure to put a cap on things like
200    ///     `Vec::with_capacity` and other allocations to avoid oversized
201    ///     allocations, and rely on the reader being finite and running out of
202    ///     data, and collections reallocating on a legitimately oversized input
203    ///     data, instead of trying to enforce arbitrary length limits.
204    /// * Types that contain other types that might be require limited reader
205    ///   (thus implementing `consensus_decode_partial_from_finite_reader`),
206    ///   should also implement it applying same rules, and in addition make
207    ///   sure to call `consensus_decode_partial_from_finite_reader` on all
208    ///   members, to avoid creating redundant `Take` wrappers
209    ///   (`Take<Take<...>>`). Failure to do so might result only in a tiny
210    ///   performance hit.
211    #[inline]
212    fn consensus_decode_partial_from_finite_reader<R: std::io::Read>(
213        r: &mut R,
214        modules: &ModuleDecoderRegistry,
215    ) -> Result<Self, DecodeError> {
216        // This method is always strictly less general than, `consensus_decode_partial`,
217        // so it's safe and make sense to default to just calling it. This way
218        // most types, that don't care about protecting against resource
219        // exhaustion due to malicious input, can just ignore it.
220        Self::consensus_decode_partial(r, modules)
221    }
222
223    #[inline]
224    fn consensus_decode_whole(
225        slice: &[u8],
226        modules: &ModuleDecoderRegistry,
227    ) -> Result<Self, DecodeError> {
228        let total_len = slice.len() as u64;
229
230        let r = &mut &slice[..];
231        let mut r = Read::take(r, total_len);
232
233        // This method is always strictly less general than, `consensus_decode_partial`,
234        // so it's safe and make sense to default to just calling it. This way
235        // most types, that don't care about protecting against resource
236        // exhaustion due to malicious input, can just ignore it.
237        let res = Self::consensus_decode_partial_from_finite_reader(&mut r, modules)?;
238        let left = r.limit();
239
240        if left != 0 {
241            return Err(fedimint_core::encoding::DecodeError::new_custom(
242                anyhow::anyhow!(
243                    "Type did not consume all bytes during decoding; expected={}; left={}; type={}",
244                    total_len,
245                    left,
246                    std::any::type_name::<Self>(),
247                ),
248            ));
249        }
250        Ok(res)
251    }
252    /// Decode an object with a well-defined format.
253    ///
254    /// This is the method that should be implemented for a typical, fixed sized
255    /// type implementing this trait. Default implementation is wrapping the
256    /// reader in [`std::io::Take`] to limit the input size to
257    /// [`MAX_DECODE_SIZE`], and forwards the call to
258    /// [`Self::consensus_decode_partial_from_finite_reader`], which is
259    /// convenient for types that override
260    /// [`Self::consensus_decode_partial_from_finite_reader`] instead.
261    #[inline]
262    fn consensus_decode_partial<R: std::io::Read>(
263        r: &mut R,
264        modules: &ModuleDecoderRegistry,
265    ) -> Result<Self, DecodeError> {
266        Self::consensus_decode_partial_from_finite_reader(
267            &mut r.take(MAX_DECODE_SIZE as u64),
268            modules,
269        )
270    }
271
272    /// Decode an object from hex
273    fn consensus_decode_hex(
274        hex: &str,
275        modules: &ModuleDecoderRegistry,
276    ) -> Result<Self, DecodeError> {
277        let bytes = Vec::<u8>::from_hex(hex)
278            .map_err(anyhow::Error::from)
279            .map_err(DecodeError::new_custom)?;
280        Decodable::consensus_decode_whole(&bytes, modules)
281    }
282}
283
284impl Encodable for SafeUrl {
285    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
286        self.to_string().consensus_encode(writer)
287    }
288}
289
290impl Decodable for SafeUrl {
291    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
292        d: &mut D,
293        modules: &ModuleDecoderRegistry,
294    ) -> Result<Self, DecodeError> {
295        String::consensus_decode_partial_from_finite_reader(d, modules)?
296            .parse::<Self>()
297            .map_err(DecodeError::from_err)
298    }
299}
300
301#[derive(Debug, Error)]
302pub struct DecodeError(pub(crate) anyhow::Error);
303
304impl DecodeError {
305    pub fn new_custom(e: anyhow::Error) -> Self {
306        Self(e)
307    }
308}
309
310impl From<anyhow::Error> for DecodeError {
311    fn from(e: anyhow::Error) -> Self {
312        Self(e)
313    }
314}
315
316macro_rules! impl_encode_decode_num_as_plain {
317    ($num_type:ty) => {
318        impl Encodable for $num_type {
319            fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
320                let bytes = self.to_be_bytes();
321                writer.write_all(&bytes[..])?;
322                Ok(())
323            }
324        }
325
326        impl Decodable for $num_type {
327            fn consensus_decode_partial<D: std::io::Read>(
328                d: &mut D,
329                _modules: &ModuleDecoderRegistry,
330            ) -> Result<Self, crate::encoding::DecodeError> {
331                let mut bytes = [0u8; (<$num_type>::BITS / 8) as usize];
332                d.read_exact(&mut bytes).map_err(DecodeError::from_err)?;
333                Ok(<$num_type>::from_be_bytes(bytes))
334            }
335        }
336    };
337}
338
339macro_rules! impl_encode_decode_num_as_bigsize {
340    ($num_type:ty) => {
341        impl Encodable for $num_type {
342            fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
343                BigSize(u64::from(*self)).consensus_encode(writer)
344            }
345        }
346
347        impl Decodable for $num_type {
348            fn consensus_decode_partial<D: std::io::Read>(
349                d: &mut D,
350                _modules: &ModuleDecoderRegistry,
351            ) -> Result<Self, crate::encoding::DecodeError> {
352                let varint = BigSize::consensus_decode_partial(d, &Default::default())
353                    .context(concat!("VarInt inside ", stringify!($num_type)))?;
354                <$num_type>::try_from(varint.0).map_err(crate::encoding::DecodeError::from_err)
355            }
356        }
357    };
358}
359
360impl_encode_decode_num_as_bigsize!(u64);
361impl_encode_decode_num_as_bigsize!(u32);
362impl_encode_decode_num_as_bigsize!(u16);
363impl_encode_decode_num_as_plain!(u8);
364
365macro_rules! impl_encode_decode_tuple {
366    ($($x:ident),*) => (
367        #[allow(non_snake_case)]
368        impl <$($x: Encodable),*> Encodable for ($($x),*) {
369            fn consensus_encode<W: std::io::Write>(&self, s: &mut W) -> Result<(), std::io::Error> {
370                let &($(ref $x),*) = self;
371                $($x.consensus_encode(s)?;)*
372                Ok(())
373            }
374        }
375
376        #[allow(non_snake_case)]
377        impl<$($x: Decodable),*> Decodable for ($($x),*) {
378            fn consensus_decode_partial<D: std::io::Read>(d: &mut D, modules: &ModuleDecoderRegistry) -> Result<Self, DecodeError> {
379                Ok(($({let $x = Decodable::consensus_decode_partial(d, modules)?; $x }),*))
380            }
381        }
382    );
383}
384
385impl_encode_decode_tuple!(T1, T2);
386impl_encode_decode_tuple!(T1, T2, T3);
387impl_encode_decode_tuple!(T1, T2, T3, T4);
388impl_encode_decode_tuple!(T1, T2, T3, T4, T5);
389impl_encode_decode_tuple!(T1, T2, T3, T4, T5, T6);
390
391impl<T> Encodable for Option<T>
392where
393    T: Encodable,
394{
395    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
396        if let Some(inner) = self {
397            1u8.consensus_encode(writer)?;
398            inner.consensus_encode(writer)?;
399        } else {
400            0u8.consensus_encode(writer)?;
401        }
402        Ok(())
403    }
404}
405
406impl<T> Decodable for Option<T>
407where
408    T: Decodable,
409{
410    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
411        d: &mut D,
412        modules: &ModuleDecoderRegistry,
413    ) -> Result<Self, DecodeError> {
414        let flag = u8::consensus_decode_partial_from_finite_reader(d, modules)?;
415        match flag {
416            0 => Ok(None),
417            1 => Ok(Some(T::consensus_decode_partial_from_finite_reader(
418                d, modules,
419            )?)),
420            _ => Err(DecodeError::from_str(
421                "Invalid flag for option enum, expected 0 or 1",
422            )),
423        }
424    }
425}
426
427impl<T, E> Encodable for Result<T, E>
428where
429    T: Encodable,
430    E: Encodable,
431{
432    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
433        match self {
434            Ok(value) => {
435                1u8.consensus_encode(writer)?;
436                value.consensus_encode(writer)?;
437            }
438            Err(error) => {
439                0u8.consensus_encode(writer)?;
440                error.consensus_encode(writer)?;
441            }
442        }
443
444        Ok(())
445    }
446}
447
448impl<T, E> Decodable for Result<T, E>
449where
450    T: Decodable,
451    E: Decodable,
452{
453    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
454        d: &mut D,
455        modules: &ModuleDecoderRegistry,
456    ) -> Result<Self, DecodeError> {
457        let flag = u8::consensus_decode_partial_from_finite_reader(d, modules)?;
458        match flag {
459            0 => Ok(Err(E::consensus_decode_partial_from_finite_reader(
460                d, modules,
461            )?)),
462            1 => Ok(Ok(T::consensus_decode_partial_from_finite_reader(
463                d, modules,
464            )?)),
465            _ => Err(DecodeError::from_str(
466                "Invalid flag for option enum, expected 0 or 1",
467            )),
468        }
469    }
470}
471
472impl<T> Encodable for Box<T>
473where
474    T: Encodable,
475{
476    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
477        self.as_ref().consensus_encode(writer)
478    }
479}
480
481impl<T> Decodable for Box<T>
482where
483    T: Decodable,
484{
485    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
486        d: &mut D,
487        modules: &ModuleDecoderRegistry,
488    ) -> Result<Self, DecodeError> {
489        Ok(Self::new(T::consensus_decode_partial_from_finite_reader(
490            d, modules,
491        )?))
492    }
493}
494
495impl Encodable for () {
496    fn consensus_encode<W: std::io::Write>(&self, _writer: &mut W) -> Result<(), std::io::Error> {
497        Ok(())
498    }
499}
500
501impl Decodable for () {
502    fn consensus_decode_partial<D: std::io::Read>(
503        _d: &mut D,
504        _modules: &ModuleDecoderRegistry,
505    ) -> Result<Self, DecodeError> {
506        Ok(())
507    }
508}
509
510impl Encodable for &str {
511    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
512        self.as_bytes().consensus_encode(writer)
513    }
514}
515
516impl Encodable for String {
517    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), Error> {
518        self.as_bytes().consensus_encode(writer)
519    }
520}
521
522impl Decodable for String {
523    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
524        d: &mut D,
525        modules: &ModuleDecoderRegistry,
526    ) -> Result<Self, DecodeError> {
527        Self::from_utf8(Decodable::consensus_decode_partial_from_finite_reader(
528            d, modules,
529        )?)
530        .map_err(DecodeError::from_err)
531    }
532}
533
534impl Encodable for SystemTime {
535    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
536        let duration = self.duration_since(UNIX_EPOCH).expect("valid duration");
537        duration.consensus_encode_dyn(writer)
538    }
539}
540
541impl Decodable for SystemTime {
542    fn consensus_decode_partial<D: std::io::Read>(
543        d: &mut D,
544        modules: &ModuleDecoderRegistry,
545    ) -> Result<Self, DecodeError> {
546        let duration = Duration::consensus_decode_partial(d, modules)?;
547        Ok(UNIX_EPOCH + duration)
548    }
549}
550
551impl Encodable for Duration {
552    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
553        self.as_secs().consensus_encode(writer)?;
554        self.subsec_nanos().consensus_encode(writer)?;
555
556        Ok(())
557    }
558}
559
560impl Decodable for Duration {
561    fn consensus_decode_partial<D: std::io::Read>(
562        d: &mut D,
563        modules: &ModuleDecoderRegistry,
564    ) -> Result<Self, DecodeError> {
565        let secs = Decodable::consensus_decode_partial(d, modules)?;
566        let nsecs = Decodable::consensus_decode_partial(d, modules)?;
567        Ok(Self::new(secs, nsecs))
568    }
569}
570
571impl Encodable for bool {
572    fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<(), Error> {
573        let bool_as_u8 = u8::from(*self);
574        writer.write_all(&[bool_as_u8])?;
575        Ok(())
576    }
577}
578
579impl Decodable for bool {
580    fn consensus_decode_partial<D: Read>(
581        d: &mut D,
582        _modules: &ModuleDecoderRegistry,
583    ) -> Result<Self, DecodeError> {
584        let mut bool_as_u8 = [0u8];
585        d.read_exact(&mut bool_as_u8)
586            .map_err(DecodeError::from_err)?;
587        match bool_as_u8[0] {
588            0 => Ok(false),
589            1 => Ok(true),
590            _ => Err(DecodeError::from_str("Out of range, expected 0 or 1")),
591        }
592    }
593}
594
595impl DecodeError {
596    // TODO: think about better name
597    #[allow(clippy::should_implement_trait)]
598    pub fn from_str(s: &'static str) -> Self {
599        #[derive(Debug)]
600        struct StrError(&'static str);
601
602        impl std::fmt::Display for StrError {
603            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
604                std::fmt::Display::fmt(&self.0, f)
605            }
606        }
607
608        impl std::error::Error for StrError {}
609
610        Self(anyhow::Error::from(StrError(s)))
611    }
612
613    pub fn from_err<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
614        Self(anyhow::Error::from(e))
615    }
616}
617
618impl std::fmt::Display for DecodeError {
619    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
620        f.write_fmt(format_args!("{:#}", self.0))
621    }
622}
623
624impl Encodable for Cow<'static, str> {
625    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
626        self.as_ref().consensus_encode(writer)
627    }
628}
629
630impl Decodable for Cow<'static, str> {
631    fn consensus_decode_partial<D: std::io::Read>(
632        d: &mut D,
633        modules: &ModuleDecoderRegistry,
634    ) -> Result<Self, DecodeError> {
635        Ok(Cow::Owned(String::consensus_decode_partial(d, modules)?))
636    }
637}
638
639/// A type that decodes `module_instance_id`-prefixed `T`s even
640/// when corresponding `Decoder` is not available.
641///
642/// All dyn-module types are encoded as:
643///
644/// ```norust
645/// module_instance_id | len_u64 | data
646/// ```
647///
648/// So clients that don't have a corresponding module, can read
649/// the `len_u64` and skip the amount of data specified in it.
650///
651/// This type makes it more convenient. It's possible to attempt
652/// to retry decoding after more modules become available by using
653/// [`DynRawFallback::redecode_raw`].
654///
655/// Notably this struct does not ignore any errors. It only skips
656/// decoding when the module decoder is not available.
657#[derive(Debug, Clone, Serialize, Deserialize)]
658pub enum DynRawFallback<T> {
659    Raw {
660        module_instance_id: ModuleInstanceId,
661        #[serde(with = "::fedimint_core::encoding::as_hex")]
662        raw: Vec<u8>,
663    },
664    Decoded(T),
665}
666
667impl<T> cmp::PartialEq for DynRawFallback<T>
668where
669    T: cmp::PartialEq + Encodable,
670{
671    fn eq(&self, other: &Self) -> bool {
672        match (self, other) {
673            (
674                Self::Raw {
675                    module_instance_id: mid_self,
676                    raw: raw_self,
677                },
678                Self::Raw {
679                    module_instance_id: mid_other,
680                    raw: raw_other,
681                },
682            ) => mid_self.eq(mid_other) && raw_self.eq(raw_other),
683            (r @ Self::Raw { .. }, d @ Self::Decoded(_))
684            | (d @ Self::Decoded(_), r @ Self::Raw { .. }) => {
685                r.consensus_encode_to_vec() == d.consensus_encode_to_vec()
686            }
687            (Self::Decoded(s), Self::Decoded(o)) => s == o,
688        }
689    }
690}
691
692impl<T> cmp::Eq for DynRawFallback<T> where T: cmp::Eq + Encodable {}
693
694impl<T> DynRawFallback<T>
695where
696    T: Decodable + 'static,
697{
698    /// Get the decoded `T` or `None` if not decoded yet
699    pub fn decoded(self) -> Option<T> {
700        match self {
701            Self::Raw { .. } => None,
702            Self::Decoded(v) => Some(v),
703        }
704    }
705
706    /// Convert into the decoded `T` and panic if not decoded yet
707    pub fn expect_decoded(self) -> T {
708        match self {
709            Self::Raw { .. } => {
710                panic!("Expected decoded value. Possibly `redecode_raw` call is missing.")
711            }
712            Self::Decoded(v) => v,
713        }
714    }
715
716    /// Get the decoded `T` and panic if not decoded yet
717    pub fn expect_decoded_ref(&self) -> &T {
718        match self {
719            Self::Raw { .. } => {
720                panic!("Expected decoded value. Possibly `redecode_raw` call is missing.")
721            }
722            Self::Decoded(v) => v,
723        }
724    }
725
726    /// Attempt to re-decode raw values with new set of of `modules`
727    ///
728    /// In certain contexts it might be necessary to try again with
729    /// a new set of modules.
730    pub fn redecode_raw(
731        self,
732        decoders: &ModuleDecoderRegistry,
733    ) -> Result<Self, crate::encoding::DecodeError> {
734        Ok(match self {
735            Self::Raw {
736                module_instance_id,
737                raw,
738            } => match decoders.get(module_instance_id) {
739                Some(decoder) => Self::Decoded(decoder.decode_complete(
740                    &mut &raw[..],
741                    raw.len() as u64,
742                    module_instance_id,
743                    decoders,
744                )?),
745                None => Self::Raw {
746                    module_instance_id,
747                    raw,
748                },
749            },
750            Self::Decoded(v) => Self::Decoded(v),
751        })
752    }
753}
754
755impl<T> From<T> for DynRawFallback<T> {
756    fn from(value: T) -> Self {
757        Self::Decoded(value)
758    }
759}
760
761impl<T> Decodable for DynRawFallback<T>
762where
763    T: Decodable + 'static,
764{
765    fn consensus_decode_partial_from_finite_reader<R: std::io::Read>(
766        reader: &mut R,
767        decoders: &ModuleDecoderRegistry,
768    ) -> Result<Self, crate::encoding::DecodeError> {
769        let module_instance_id =
770            fedimint_core::core::ModuleInstanceId::consensus_decode_partial_from_finite_reader(
771                reader, decoders,
772            )?;
773        Ok(match decoders.get(module_instance_id) {
774            Some(decoder) => {
775                let total_len_u64 =
776                    u64::consensus_decode_partial_from_finite_reader(reader, decoders)?;
777                Self::Decoded(decoder.decode_complete(
778                    reader,
779                    total_len_u64,
780                    module_instance_id,
781                    decoders,
782                )?)
783            }
784            None => {
785                // since the decoder is not available, just read the raw data
786                Self::Raw {
787                    module_instance_id,
788                    raw: Vec::consensus_decode_partial_from_finite_reader(reader, decoders)?,
789                }
790            }
791        })
792    }
793}
794
795impl<T> Encodable for DynRawFallback<T>
796where
797    T: Encodable,
798{
799    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
800        match self {
801            Self::Raw {
802                module_instance_id,
803                raw,
804            } => {
805                module_instance_id.consensus_encode(writer)?;
806                raw.consensus_encode(writer)?;
807                Ok(())
808            }
809            Self::Decoded(v) => v.consensus_encode(writer),
810        }
811    }
812}
813
814#[cfg(test)]
815mod tests {
816    use std::fmt::Debug;
817    use std::io::Cursor;
818
819    use super::*;
820    use crate::encoding::{Decodable, Encodable};
821    use crate::module::registry::ModuleRegistry;
822
823    pub(crate) fn test_roundtrip<T>(value: &T)
824    where
825        T: Encodable + Decodable + Eq + Debug,
826    {
827        let mut bytes = Vec::new();
828        value.consensus_encode(&mut bytes).unwrap();
829
830        let mut cursor = Cursor::new(bytes);
831        let decoded =
832            T::consensus_decode_partial(&mut cursor, &ModuleDecoderRegistry::default()).unwrap();
833        assert_eq!(value, &decoded);
834    }
835
836    pub(crate) fn test_roundtrip_expected<T>(value: &T, expected: &[u8])
837    where
838        T: Encodable + Decodable + Eq + Debug,
839    {
840        let mut bytes = Vec::new();
841        value.consensus_encode(&mut bytes).unwrap();
842        assert_eq!(&expected, &bytes);
843
844        let mut cursor = Cursor::new(bytes);
845        let decoded =
846            T::consensus_decode_partial(&mut cursor, &ModuleDecoderRegistry::default()).unwrap();
847        assert_eq!(value, &decoded);
848    }
849
850    #[derive(Debug, Eq, PartialEq, Encodable, Decodable)]
851    enum NoDefaultEnum {
852        Foo,
853        Bar(u32, String),
854        Baz { baz: u8 },
855    }
856
857    #[derive(Debug, Eq, PartialEq, Encodable, Decodable)]
858    enum DefaultEnum {
859        Foo,
860        Bar(u32, String),
861        #[encodable_default]
862        Default {
863            variant: u64,
864            bytes: Vec<u8>,
865        },
866    }
867
868    #[test_log::test]
869    fn test_derive_enum_no_default_roundtrip_success() {
870        let enums = [
871            NoDefaultEnum::Foo,
872            NoDefaultEnum::Bar(
873                42,
874                "The answer to life, the universe, and everything".to_string(),
875            ),
876            NoDefaultEnum::Baz { baz: 0 },
877        ];
878
879        for e in enums {
880            test_roundtrip(&e);
881        }
882    }
883
884    #[test_log::test]
885    fn test_derive_enum_no_default_decode_fail() {
886        let unknown_variant = DefaultEnum::Default {
887            variant: 42,
888            bytes: vec![0, 1, 2, 3],
889        };
890        let mut unknown_variant_encoding = vec![];
891        unknown_variant
892            .consensus_encode(&mut unknown_variant_encoding)
893            .unwrap();
894
895        let mut cursor = Cursor::new(&unknown_variant_encoding);
896        let decode_res =
897            NoDefaultEnum::consensus_decode_partial(&mut cursor, &ModuleRegistry::default());
898
899        match decode_res {
900            Ok(_) => panic!("Should return error"),
901            Err(e) => assert!(e.to_string().contains("Invalid enum variant")),
902        }
903    }
904
905    #[test_log::test]
906    fn test_derive_enum_default_decode_success() {
907        let unknown_variant = NoDefaultEnum::Baz { baz: 123 };
908        let mut unknown_variant_encoding = vec![];
909        unknown_variant
910            .consensus_encode(&mut unknown_variant_encoding)
911            .unwrap();
912
913        let mut cursor = Cursor::new(&unknown_variant_encoding);
914        let decode_res =
915            DefaultEnum::consensus_decode_partial(&mut cursor, &ModuleRegistry::default());
916
917        assert_eq!(
918            decode_res.unwrap(),
919            DefaultEnum::Default {
920                variant: 2,
921                bytes: vec![123],
922            }
923        );
924    }
925
926    #[test_log::test]
927    fn test_derive_struct() {
928        #[derive(Debug, Encodable, Decodable, Eq, PartialEq)]
929        struct TestStruct {
930            vec: Vec<u8>,
931            num: u32,
932        }
933
934        let reference = TestStruct {
935            vec: vec![1, 2, 3],
936            num: 42,
937        };
938        let bytes = [3, 1, 2, 3, 42];
939
940        test_roundtrip_expected(&reference, &bytes);
941    }
942
943    #[test_log::test]
944    fn test_derive_tuple_struct() {
945        #[derive(Debug, Encodable, Decodable, Eq, PartialEq)]
946        struct TestStruct(Vec<u8>, u32);
947
948        let reference = TestStruct(vec![1, 2, 3], 42);
949        let bytes = [3, 1, 2, 3, 42];
950
951        test_roundtrip_expected(&reference, &bytes);
952    }
953
954    #[test_log::test]
955    fn test_derive_enum() {
956        #[derive(Debug, Encodable, Decodable, Eq, PartialEq)]
957        enum TestEnum {
958            Foo(Option<u64>),
959            Bar { bazz: Vec<u8> },
960        }
961
962        let test_cases = [
963            (TestEnum::Foo(Some(42)), vec![0, 2, 1, 42]),
964            (TestEnum::Foo(None), vec![0, 1, 0]),
965            (
966                TestEnum::Bar {
967                    bazz: vec![1, 2, 3],
968                },
969                vec![1, 4, 3, 1, 2, 3],
970            ),
971        ];
972
973        for (reference, bytes) in test_cases {
974            test_roundtrip_expected(&reference, &bytes);
975        }
976    }
977
978    #[test_log::test]
979    fn test_systemtime() {
980        test_roundtrip(&fedimint_core::time::now());
981    }
982
983    #[test]
984    fn test_derive_empty_enum_decode() {
985        #[derive(Debug, Encodable, Decodable)]
986        enum NotConstructable {}
987
988        let vec = vec![42u8];
989        let mut cursor = Cursor::new(vec);
990
991        assert!(
992            NotConstructable::consensus_decode_partial(
993                &mut cursor,
994                &ModuleDecoderRegistry::default()
995            )
996            .is_err()
997        );
998    }
999
1000    #[test]
1001    fn test_custom_index_enum() {
1002        #[derive(Debug, PartialEq, Eq, Encodable, Decodable)]
1003        enum Old {
1004            Foo,
1005            Bar,
1006            Baz,
1007        }
1008
1009        #[derive(Debug, PartialEq, Eq, Encodable, Decodable)]
1010        enum New {
1011            #[encodable(index = 0)]
1012            Foo,
1013            #[encodable(index = 2)]
1014            Baz,
1015            #[encodable_default]
1016            Default { variant: u64, bytes: Vec<u8> },
1017        }
1018
1019        let test_vector = vec![
1020            (Old::Foo, New::Foo),
1021            (
1022                Old::Bar,
1023                New::Default {
1024                    variant: 1,
1025                    bytes: vec![],
1026                },
1027            ),
1028            (Old::Baz, New::Baz),
1029        ];
1030
1031        for (old, new) in test_vector {
1032            let old_bytes = old.consensus_encode_to_vec();
1033            let decoded_new = New::consensus_decode_whole(&old_bytes, &ModuleRegistry::default())
1034                .expect("Decoding failed");
1035            assert_eq!(decoded_new, new);
1036        }
1037    }
1038
1039    fn encode_value<T: Encodable>(value: &T) -> Vec<u8> {
1040        let mut writer = Vec::new();
1041        value.consensus_encode(&mut writer).unwrap();
1042        writer
1043    }
1044
1045    fn decode_value<T: Decodable>(bytes: &[u8]) -> T {
1046        T::consensus_decode_whole(bytes, &ModuleDecoderRegistry::default()).unwrap()
1047    }
1048
1049    fn keeps_ordering_after_serialization<T: Ord + Encodable + Decodable + Debug>(mut vec: Vec<T>) {
1050        vec.sort();
1051        let mut encoded = vec.iter().map(encode_value).collect::<Vec<_>>();
1052        encoded.sort();
1053        let decoded = encoded.iter().map(|v| decode_value(v)).collect::<Vec<_>>();
1054        for (i, (a, b)) in vec.iter().zip(decoded.iter()).enumerate() {
1055            assert_eq!(a, b, "difference at index {i}");
1056        }
1057    }
1058
1059    #[test]
1060    fn test_lexicographical_sorting() {
1061        #[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Encodable, Decodable)]
1062        struct TestAmount(u64);
1063
1064        #[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Encodable, Decodable)]
1065        struct TestComplexAmount(u16, u32, u64);
1066
1067        #[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Encodable, Decodable)]
1068        struct Text(String);
1069
1070        let amounts = (0..20000).map(TestAmount).collect::<Vec<_>>();
1071        keeps_ordering_after_serialization(amounts);
1072
1073        let complex_amounts = (10..20000)
1074            .flat_map(|i| {
1075                (i - 1..=i + 1).flat_map(move |j| {
1076                    (i - 1..=i + 1).map(move |k| TestComplexAmount(i as u16, j as u32, k as u64))
1077                })
1078            })
1079            .collect::<Vec<_>>();
1080        keeps_ordering_after_serialization(complex_amounts);
1081
1082        let texts = (' '..'~')
1083            .flat_map(|i| {
1084                (' '..'~')
1085                    .map(|j| Text(format!("{i}{j}")))
1086                    .collect::<Vec<_>>()
1087            })
1088            .collect::<Vec<_>>();
1089        keeps_ordering_after_serialization(texts);
1090
1091        // bitcoin structures are not lexicographically sortable so we cannot
1092        // test them here. in future we may crate a wrapper type that is
1093        // lexicographically sortable to use when needed
1094    }
1095}