fedimint_core/
tiered_multi.rs

1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3
4use fedimint_core::encoding::{Decodable, DecodeError, Encodable};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::module::registry::ModuleDecoderRegistry;
9use crate::{Amount, Tiered};
10
11/// Represents notes of different denominations.
12///
13/// **Attention:** care has to be taken when constructing this to avoid overflow
14/// when calculating the total amount represented. As it is prudent to limit
15/// both the maximum note amount and maximum note count per transaction this
16/// shouldn't be a problem in practice though.
17#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
18pub struct TieredMulti<T>(Tiered<Vec<T>>);
19
20impl<T> TieredMulti<T> {
21    /// Returns a new `TieredMulti` with the given `BTreeMap` map
22    pub fn new(map: BTreeMap<Amount, Vec<T>>) -> Self {
23        Self(map.into_iter().filter(|(_, v)| !v.is_empty()).collect())
24    }
25
26    /// Returns a new `TieredMulti` from a collection of `Tiered` structs.
27    /// The `Tiered` structs are expected to be structurally equal, otherwise
28    /// this function will panic.
29    pub fn new_aggregate_from_tiered_iter(tiered_iter: impl Iterator<Item = Tiered<T>>) -> Self {
30        let mut tiered_multi = Self::default();
31
32        for tiered in tiered_iter {
33            for (amt, val) in tiered {
34                tiered_multi.push(amt, val);
35            }
36        }
37
38        // TODO: This only asserts that the output is structurally sound, not the input.
39        // For example, an input with tier `Amount`s of [[1, 2], [4, 8]] would currently
40        // be accepted even though it is not structurally sound.
41        assert!(
42            tiered_multi
43                .summary()
44                .iter()
45                .map(|(_tier, count)| count)
46                .all_equal(),
47            "The supplied Tiered structs were not structurally equal"
48        );
49
50        tiered_multi
51    }
52
53    /// Returns the total value of all notes in msat as `Amount`
54    pub fn total_amount(&self) -> Amount {
55        let milli_sat = self
56            .0
57            .iter()
58            .map(|(tier, notes)| tier.msats * (notes.len() as u64))
59            .sum();
60        Amount::from_msats(milli_sat)
61    }
62
63    /// Returns the number of items in all vectors
64    pub fn count_items(&self) -> usize {
65        self.0.values().map(Vec::len).sum()
66    }
67
68    /// Returns the number of tiers
69    pub fn count_tiers(&self) -> usize {
70        self.0.count_tiers()
71    }
72
73    /// Returns the summary of number of items in each tier
74    pub fn summary(&self) -> TieredCounts {
75        TieredCounts(
76            self.iter()
77                .map(|(amount, values)| (amount, values.len()))
78                .collect(),
79        )
80    }
81
82    /// Verifies whether all vectors in all tiers are empty
83    pub fn is_empty(&self) -> bool {
84        self.assert_invariants();
85        self.count_items() == 0
86    }
87
88    /// Returns an borrowing iterator
89    pub fn iter(&self) -> impl Iterator<Item = (Amount, &Vec<T>)> {
90        self.0.iter()
91    }
92
93    /// Returns an iterator over every `(Amount, &T)`
94    ///
95    /// Note: The order of the elements is important:
96    /// from the lowest tier to the highest, then in order of elements in the
97    /// Vec
98    pub fn iter_items(&self) -> impl DoubleEndedIterator<Item = (Amount, &T)> {
99        // Note: If you change the method implementation, make sure that the returned
100        // order of the elements stays consistent.
101        self.0
102            .iter()
103            .flat_map(|(amt, notes)| notes.iter().map(move |c| (amt, c)))
104    }
105
106    /// Returns an consuming iterator over every `(Amount, T)`
107    ///
108    /// Note: The order of the elements is important:
109    /// from the lowest tier to the highest, then in order of elements in the
110    /// Vec
111    pub fn into_iter_items(self) -> impl DoubleEndedIterator<Item = (Amount, T)> {
112        // Note: If you change the method implementation, make sure that the returned
113        // order of the elements stays consistent.
114        self.0
115            .into_iter()
116            .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c)))
117    }
118
119    pub fn push(&mut self, amt: Amount, val: T) {
120        self.0.entry(amt).or_default().push(val);
121    }
122
123    fn assert_invariants(&self) {
124        // Just for compactness and determinism, we don't want entries with 0 items
125        #[cfg(debug_assertions)]
126        self.iter().for_each(|(_, v)| debug_assert!(!v.is_empty()));
127    }
128}
129
130impl<C> FromIterator<(Amount, C)> for TieredMulti<C> {
131    fn from_iter<T: IntoIterator<Item = (Amount, C)>>(iter: T) -> Self {
132        let mut res = Self::default();
133        res.extend(iter);
134        res.assert_invariants();
135        res
136    }
137}
138
139impl<C> IntoIterator for TieredMulti<C>
140where
141    C: 'static + Send,
142{
143    type Item = (Amount, Vec<C>);
144    type IntoIter = std::collections::btree_map::IntoIter<Amount, Vec<C>>;
145
146    fn into_iter(self) -> Self::IntoIter {
147        self.0.into_iter()
148    }
149}
150
151impl<C> Default for TieredMulti<C> {
152    fn default() -> Self {
153        Self(Tiered::default())
154    }
155}
156
157impl<C> Extend<(Amount, C)> for TieredMulti<C> {
158    fn extend<T: IntoIterator<Item = (Amount, C)>>(&mut self, iter: T) {
159        for (amount, note) in iter {
160            self.0.entry(amount).or_default().push(note);
161        }
162    }
163}
164
165impl<C> Encodable for TieredMulti<C>
166where
167    C: Encodable + 'static,
168{
169    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
170        self.0.consensus_encode(writer)
171    }
172}
173
174impl<C> Decodable for TieredMulti<C>
175where
176    C: Decodable + 'static,
177{
178    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
179        d: &mut D,
180        modules: &ModuleDecoderRegistry,
181    ) -> Result<Self, DecodeError> {
182        Ok(Self(Tiered::consensus_decode_partial_from_finite_reader(
183            d, modules,
184        )?))
185    }
186}
187
188#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize, Clone)]
189pub struct TieredCounts(Tiered<usize>);
190
191impl TieredCounts {
192    pub fn inc(&mut self, tier: Amount, n: usize) {
193        if 0 < n {
194            *self.0.get_mut_or_default(tier) += n;
195        }
196    }
197
198    pub fn dec(&mut self, tier: Amount) {
199        match self.0.entry(tier) {
200            Entry::Vacant(_) => panic!("Trying to decrement an empty tier"),
201            Entry::Occupied(mut c) => {
202                assert!(*c.get() != 0);
203                if *c.get() == 1 {
204                    c.remove_entry();
205                } else {
206                    *c.get_mut() -= 1;
207                }
208            }
209        }
210        self.assert_invariants();
211    }
212
213    pub fn iter(&self) -> impl Iterator<Item = (Amount, usize)> + '_ {
214        self.0.iter().map(|(k, v)| (k, *v))
215    }
216
217    pub fn total_amount(&self) -> Amount {
218        self.0.iter().map(|(k, v)| k * (*v as u64)).sum::<Amount>()
219    }
220
221    pub fn count_items(&self) -> usize {
222        self.0.iter().map(|(_, v)| *v).sum()
223    }
224
225    pub fn count_tiers(&self) -> usize {
226        self.0.count_tiers()
227    }
228
229    pub fn is_empty(&self) -> bool {
230        self.count_items() == 0
231    }
232
233    pub fn get(&self, tier: Amount) -> usize {
234        self.assert_invariants();
235        self.0.get(tier).copied().unwrap_or_default()
236    }
237
238    fn assert_invariants(&self) {
239        // Just for compactness and determinism, we don't want entries with 0 count
240        #[cfg(debug_assertions)]
241        self.iter().for_each(|(_, count)| debug_assert!(0 < count));
242    }
243}
244
245impl FromIterator<(Amount, usize)> for TieredCounts {
246    fn from_iter<I: IntoIterator<Item = (Amount, usize)>>(iter: I) -> Self {
247        Self(iter.into_iter().filter(|(_, count)| *count != 0).collect())
248    }
249}
250
// Unit tests for this module, declared in a separate `test` submodule file and
// compiled only when `cfg(test)` is set.
#[cfg(test)]
mod test;