fedimint_core/
tiered_multi.rs1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3
4use fedimint_core::encoding::{Decodable, DecodeError, Encodable};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::module::registry::ModuleDecoderRegistry;
9use crate::{Amount, Tiered};
10
11#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
18pub struct TieredMulti<T>(Tiered<Vec<T>>);
19
20impl<T> TieredMulti<T> {
21 pub fn new(map: BTreeMap<Amount, Vec<T>>) -> Self {
23 Self(map.into_iter().filter(|(_, v)| !v.is_empty()).collect())
24 }
25
26 pub fn new_aggregate_from_tiered_iter(tiered_iter: impl Iterator<Item = Tiered<T>>) -> Self {
30 let mut tiered_multi = Self::default();
31
32 for tiered in tiered_iter {
33 for (amt, val) in tiered {
34 tiered_multi.push(amt, val);
35 }
36 }
37
38 assert!(
42 tiered_multi
43 .summary()
44 .iter()
45 .map(|(_tier, count)| count)
46 .all_equal(),
47 "The supplied Tiered structs were not structurally equal"
48 );
49
50 tiered_multi
51 }
52
53 pub fn total_amount(&self) -> Amount {
55 let milli_sat = self
56 .0
57 .iter()
58 .map(|(tier, notes)| tier.msats * (notes.len() as u64))
59 .sum();
60 Amount::from_msats(milli_sat)
61 }
62
63 pub fn count_items(&self) -> usize {
65 self.0.values().map(Vec::len).sum()
66 }
67
68 pub fn count_tiers(&self) -> usize {
70 self.0.count_tiers()
71 }
72
73 pub fn summary(&self) -> TieredCounts {
75 TieredCounts(
76 self.iter()
77 .map(|(amount, values)| (amount, values.len()))
78 .collect(),
79 )
80 }
81
82 pub fn is_empty(&self) -> bool {
84 self.assert_invariants();
85 self.count_items() == 0
86 }
87
88 pub fn iter(&self) -> impl Iterator<Item = (Amount, &Vec<T>)> {
90 self.0.iter()
91 }
92
93 pub fn iter_items(&self) -> impl DoubleEndedIterator<Item = (Amount, &T)> {
99 self.0
102 .iter()
103 .flat_map(|(amt, notes)| notes.iter().map(move |c| (amt, c)))
104 }
105
106 pub fn into_iter_items(self) -> impl DoubleEndedIterator<Item = (Amount, T)> {
112 self.0
115 .into_iter()
116 .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c)))
117 }
118
119 pub fn push(&mut self, amt: Amount, val: T) {
120 self.0.entry(amt).or_default().push(val);
121 }
122
123 fn assert_invariants(&self) {
124 #[cfg(debug_assertions)]
126 self.iter().for_each(|(_, v)| debug_assert!(!v.is_empty()));
127 }
128}
129
130impl<C> FromIterator<(Amount, C)> for TieredMulti<C> {
131 fn from_iter<T: IntoIterator<Item = (Amount, C)>>(iter: T) -> Self {
132 let mut res = Self::default();
133 res.extend(iter);
134 res.assert_invariants();
135 res
136 }
137}
138
139impl<C> IntoIterator for TieredMulti<C>
140where
141 C: 'static + Send,
142{
143 type Item = (Amount, Vec<C>);
144 type IntoIter = std::collections::btree_map::IntoIter<Amount, Vec<C>>;
145
146 fn into_iter(self) -> Self::IntoIter {
147 self.0.into_iter()
148 }
149}
150
151impl<C> Default for TieredMulti<C> {
152 fn default() -> Self {
153 Self(Tiered::default())
154 }
155}
156
157impl<C> Extend<(Amount, C)> for TieredMulti<C> {
158 fn extend<T: IntoIterator<Item = (Amount, C)>>(&mut self, iter: T) {
159 for (amount, note) in iter {
160 self.0.entry(amount).or_default().push(note);
161 }
162 }
163}
164
165impl<C> Encodable for TieredMulti<C>
166where
167 C: Encodable + 'static,
168{
169 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
170 self.0.consensus_encode(writer)
171 }
172}
173
174impl<C> Decodable for TieredMulti<C>
175where
176 C: Decodable + 'static,
177{
178 fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
179 d: &mut D,
180 modules: &ModuleDecoderRegistry,
181 ) -> Result<Self, DecodeError> {
182 Ok(Self(Tiered::consensus_decode_partial_from_finite_reader(
183 d, modules,
184 )?))
185 }
186}
187
188#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize, Clone)]
189pub struct TieredCounts(Tiered<usize>);
190
191impl TieredCounts {
192 pub fn inc(&mut self, tier: Amount, n: usize) {
193 if 0 < n {
194 *self.0.get_mut_or_default(tier) += n;
195 }
196 }
197
198 pub fn dec(&mut self, tier: Amount) {
199 match self.0.entry(tier) {
200 Entry::Vacant(_) => panic!("Trying to decrement an empty tier"),
201 Entry::Occupied(mut c) => {
202 assert!(*c.get() != 0);
203 if *c.get() == 1 {
204 c.remove_entry();
205 } else {
206 *c.get_mut() -= 1;
207 }
208 }
209 }
210 self.assert_invariants();
211 }
212
213 pub fn iter(&self) -> impl Iterator<Item = (Amount, usize)> + '_ {
214 self.0.iter().map(|(k, v)| (k, *v))
215 }
216
217 pub fn total_amount(&self) -> Amount {
218 self.0.iter().map(|(k, v)| k * (*v as u64)).sum::<Amount>()
219 }
220
221 pub fn count_items(&self) -> usize {
222 self.0.iter().map(|(_, v)| *v).sum()
223 }
224
225 pub fn count_tiers(&self) -> usize {
226 self.0.count_tiers()
227 }
228
229 pub fn is_empty(&self) -> bool {
230 self.count_items() == 0
231 }
232
233 pub fn get(&self, tier: Amount) -> usize {
234 self.assert_invariants();
235 self.0.get(tier).copied().unwrap_or_default()
236 }
237
238 fn assert_invariants(&self) {
239 #[cfg(debug_assertions)]
241 self.iter().for_each(|(_, count)| debug_assert!(0 < count));
242 }
243}
244
245impl FromIterator<(Amount, usize)> for TieredCounts {
246 fn from_iter<I: IntoIterator<Item = (Amount, usize)>>(iter: I) -> Self {
247 Self(iter.into_iter().filter(|(_, count)| *count != 0).collect())
248 }
249}
250
251#[cfg(test)]
252mod test;