1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use std::collections::BTreeMap;

use fedimint_core::Amount;
use serde::{Deserialize, Serialize};

use crate::encoding::{Decodable, DecodeError, Encodable};
use crate::module::registry::ModuleDecoderRegistry;

/// Error returned by [`Tiered::tier`] when the requested amount is not one of
/// the tiers in the collection.
///
/// The wrapped [`Amount`] is the tier that was looked up.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)]
pub struct InvalidAmountTierError(pub Amount);

impl std::fmt::Display for InvalidAmountTierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Amount tier unknown to mint: {}", self.0)
    }
}

// A public error type should implement `Error` so callers can box it as
// `Box<dyn std::error::Error>` or chain it as the `source()` of their own errors.
// `Debug` and `Display` are provided above, so the default methods suffice.
impl std::error::Error for InvalidAmountTierError {}

/// A collection of values keyed by [`Amount`] tier.
///
/// Backed by a `BTreeMap`, so each amount appears at most once and tiers are
/// kept (and iterated) in ascending order of amount. With `#[serde(transparent)]`
/// it (de)serializes exactly like the inner map.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Tiered<T>(BTreeMap<Amount, T>);

impl<T> Default for Tiered<T> {
    fn default() -> Self {
        Self(Default::default())
    }
}

impl<T> Tiered<T> {
    /// Returns the highest tier amount.
    ///
    /// # Panics
    ///
    /// Panics if the collection has no tiers.
    pub fn max_tier(&self) -> &Amount {
        // `BTreeMap` keys are sorted ascending, so the last key is the maximum.
        // Reading it from the back avoids visiting every key as `Iterator::max` would.
        self.0.keys().next_back().expect("has tiers")
    }

    /// Returns `true` if `self` and `other` contain exactly the same tier
    /// amounts, ignoring the values stored under them.
    pub fn structural_eq<O>(&self, other: &Tiered<O>) -> bool {
        self.0.keys().eq(other.0.keys())
    }

    /// Returns a reference to the value stored for the tier `amount`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidAmountTierError`] carrying `amount` if no such tier exists.
    pub fn tier(&self, amount: &Amount) -> Result<&T, InvalidAmountTierError> {
        self.0.get(amount).ok_or(InvalidAmountTierError(*amount))
    }

    /// Returns the number of tiers.
    pub fn count_tiers(&self) -> usize {
        self.0.len()
    }

    /// Iterates over the tier amounts in ascending order.
    pub fn tiers(&self) -> impl DoubleEndedIterator<Item = &Amount> {
        self.0.keys()
    }

    /// Iterates over `(amount, value)` pairs in ascending order of amount.
    pub fn iter(&self) -> impl Iterator<Item = (Amount, &T)> {
        self.0.iter().map(|(amt, value)| (*amt, value))
    }

    /// Returns the value for tier `amt`, or `None` if there is no such tier.
    pub fn get(&self, amt: Amount) -> Option<&T> {
        self.0.get(&amt)
    }

    /// Returns a mutable reference to the value for tier `amt`, or `None` if
    /// there is no such tier.
    pub fn get_mut(&mut self, amt: Amount) -> Option<&mut T> {
        self.0.get_mut(&amt)
    }

    /// Stores `v` under tier `amt` and returns the value previously stored
    /// there, if any.
    pub fn insert(&mut self, amt: Amount, v: T) -> Option<T> {
        self.0.insert(amt, v)
    }

    /// Returns a mutable reference to the value for tier `amt`. If the tier
    /// does not exist yet, it is inserted with `T::default()` first.
    pub fn get_mut_or_default(&mut self, amt: Amount) -> &mut T
    where
        T: Default,
    {
        self.0.entry(amt).or_default()
    }

    /// Returns the map entry for tier `amt` so it can be inspected or updated
    /// in place.
    // No `T: Default` bound: `BTreeMap::entry` does not need one, and callers with
    // non-`Default` values can still use `or_insert` / `or_insert_with`.
    pub fn entry(&mut self, amt: Amount) -> std::collections::btree_map::Entry<'_, Amount, T> {
        self.0.entry(amt)
    }

    /// Borrows the underlying amount-to-value map.
    pub fn as_map(&self) -> &BTreeMap<Amount, T> {
        &self.0
    }
}

impl Tiered<()> {
    /// Generates the denominations `1, b, b², …` msat, where `b` is
    /// `denomination_base`, up to and including `max`.
    ///
    /// Returns no tiers if `max` is below 1 msat. Generation stops early rather
    /// than overflowing if the next power of `b` does not fit in a `u64` msat value.
    ///
    /// # Panics
    ///
    /// Panics if `denomination_base` is less than 2. Such a base never grows past
    /// `max`, so generation would otherwise never terminate.
    pub fn gen_denominations(denomination_base: u16, max: Amount) -> Tiered<()> {
        assert!(
            denomination_base >= 2,
            "denomination base must be at least 2, got {}",
            denomination_base
        );

        let base = u64::from(denomination_base);
        let mut amounts = vec![];

        let mut msats: u64 = 1;
        loop {
            let denomination = Amount::from_msats(msats);
            if denomination > max {
                break;
            }
            amounts.push((denomination, ()));

            // Checked multiplication: a wrapped product could fall back below
            // `max` (e.g. to 0) and make the loop run forever.
            match msats.checked_mul(base) {
                Some(next) => msats = next,
                None => break,
            }
        }

        amounts.into_iter().collect()
    }
}

impl<T> FromIterator<(Amount, T)> for Tiered<T> {
    fn from_iter<I: IntoIterator<Item = (Amount, T)>>(iter: I) -> Self {
        Tiered(iter.into_iter().collect())
    }
}

impl<C> Encodable for Tiered<C>
where
    C: Encodable,
{
    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
        self.0.consensus_encode(writer)
    }
}

impl<C> Decodable for Tiered<C>
where
    C: Decodable,
{
    fn consensus_decode<D: std::io::Read>(
        d: &mut D,
        modules: &ModuleDecoderRegistry,
    ) -> Result<Self, DecodeError> {
        Ok(Tiered(BTreeMap::consensus_decode(d, modules)?))
    }
}

#[cfg(test)]
mod tests {
    use fedimint_core::Amount;

    use super::Tiered;

    /// The generated tier amounts, in ascending order.
    fn tier_amounts(denominations: &Tiered<()>) -> Vec<Amount> {
        denominations.tiers().copied().collect()
    }

    /// Converts msat values into `Amount`s for comparison.
    fn msats(values: &[u64]) -> Vec<Amount> {
        values.iter().copied().map(Amount::from_msats).collect()
    }

    #[test]
    fn tier_generation_including_max_amount() {
        let max_amount = Amount::from_msats(16);
        let denominations = Tiered::gen_denominations(2, max_amount);

        // Check the exact amounts, not just how many there are.
        assert_eq!(tier_amounts(&denominations), msats(&[1, 2, 4, 8, 16]));
    }

    #[test]
    fn tier_generation_base_10() {
        let max_amount = Amount::from_msats(10000);
        let denominations = Tiered::gen_denominations(10, max_amount);

        assert_eq!(
            tier_amounts(&denominations),
            msats(&[1, 10, 100, 1000, 10_000])
        );
    }

    #[test]
    fn tier_generation_max_between_powers() {
        // `max` is not itself a power of the base: the largest tier is the
        // greatest power below it.
        let denominations = Tiered::gen_denominations(2, Amount::from_msats(15));

        assert_eq!(tier_amounts(&denominations), msats(&[1, 2, 4, 8]));
        assert_eq!(*denominations.max_tier(), Amount::from_msats(8));
    }
}