1use std::collections::BTreeMap;
2
3use anyhow::ensure;
4use fedimint_client::DynGlobalClientContext;
5use fedimint_client_module::module::OutPointRange;
6use fedimint_client_module::sm::{ClientSMDatabaseTransaction, State, StateTransition};
7use fedimint_core::PeerId;
8use fedimint_core::core::OperationId;
9use fedimint_core::db::IDatabaseTransactionOpsCoreTyped;
10use fedimint_core::encoding::{Decodable, Encodable};
11use fedimint_mintv2_common::Denomination;
12use tbs::{AggregatePublicKey, BlindedSignatureShare, PublicKeyShare, aggregate_signature_shares};
13
14use crate::api::MintV2ModuleApi;
15use crate::client_db::SpendableNoteKey;
16use crate::{MintClientContext, NoteIssuanceRequest};
17
/// Client state machine that turns a mint output (or, without an output
/// range, a recovery request) into spendable notes.
///
/// It waits for the federation's blind signature shares for
/// [`OutputSMCommon::issuance_requests`], aggregates and verifies them, and
/// stores each resulting note under a [`SpendableNoteKey`].
///
/// NOTE(review): the derived `Encodable`/`Decodable` impls presumably encode
/// fields in declaration order, so reordering fields may break persisted
/// state — confirm before changing.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
pub struct MintOutputStateMachine {
    /// Data shared by every state of this machine.
    pub common: OutputSMCommon,
    /// Current state; only `OutputSMState::Pending` has an outgoing transition.
    pub state: OutputSMState,
}
23
/// Data shared by all states of a [`MintOutputStateMachine`].
///
/// NOTE(review): field order is presumably part of the derived encoding —
/// confirm before reordering.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
pub struct OutputSMCommon {
    /// Operation this state machine belongs to (returned by `operation_id`).
    pub operation_id: OperationId,
    /// Output range of the transaction issuing the notes.
    ///
    /// `Some(range)`: the machine waits for `range.txid` to be accepted and
    /// then fetches the signature shares for this range. `None`: the shares
    /// are fetched through `fetch_signature_shares_recovery` without waiting
    /// for any transaction.
    pub range: Option<OutPointRange>,
    /// Requests for the notes to issue. Each peer is expected to return
    /// exactly one blind signature share per request, in this order.
    pub issuance_requests: Vec<NoteIssuanceRequest>,
}
30
/// States of a [`MintOutputStateMachine`].
///
/// NOTE(review): the derived `Encodable`/`Decodable` impls presumably depend
/// on variant order, so reordering variants may break persisted state —
/// confirm before changing.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
pub enum OutputSMState {
    /// Waiting for the federation's signature shares; the only state with an
    /// outgoing transition.
    Pending,
    /// Waiting for the shares failed (`await_tx_accepted` returned an error);
    /// no notes were stored.
    Aborted,
    /// An aggregated signature failed to verify against its denomination's
    /// aggregate public key. Notes for earlier issuance requests may already
    /// have been stored.
    Failure,
    /// Every issuance request produced a verified note that was stored as
    /// spendable.
    Success,
}
46
47impl State for MintOutputStateMachine {
48 type ModuleContext = MintClientContext;
49
50 fn transitions(
51 &self,
52 context: &Self::ModuleContext,
53 global_context: &DynGlobalClientContext,
54 ) -> Vec<StateTransition<Self>> {
55 let context = context.clone();
56
57 match &self.state {
58 OutputSMState::Pending => {
59 vec![StateTransition::new(
60 Self::await_signature_shares(
61 global_context.clone(),
62 self.common.range,
63 self.common.issuance_requests.clone(),
64 context.tbs_pks.clone(),
65 ),
66 move |dbtx, signature_shares, old_state| {
67 let balance_update_sender = context.balance_update_sender.clone();
68
69 dbtx.module_tx()
70 .on_commit(move || balance_update_sender.send_replace(()));
71
72 Box::pin(Self::transition_outcome_ready(
73 dbtx,
74 signature_shares,
75 old_state,
76 context.tbs_agg_pks.clone(),
77 ))
78 },
79 )]
80 }
81 OutputSMState::Aborted | OutputSMState::Failure | OutputSMState::Success => {
82 vec![]
83 }
84 }
85 }
86
87 fn operation_id(&self) -> OperationId {
88 self.common.operation_id
89 }
90}
91
92impl MintOutputStateMachine {
93 async fn await_signature_shares(
94 global_context: DynGlobalClientContext,
95 range: Option<OutPointRange>,
96 issuance_requests: Vec<NoteIssuanceRequest>,
97 tbs_pks: BTreeMap<Denomination, BTreeMap<PeerId, PublicKeyShare>>,
98 ) -> Result<BTreeMap<PeerId, Vec<BlindedSignatureShare>>, String> {
99 if let Some(range) = range {
100 global_context.await_tx_accepted(range.txid).await?;
101
102 let shares = global_context
103 .module_api()
104 .fetch_signature_shares(range, issuance_requests, tbs_pks)
105 .await;
106
107 Ok(shares)
108 } else {
109 let shares = global_context
110 .module_api()
111 .fetch_signature_shares_recovery(issuance_requests, tbs_pks)
112 .await;
113
114 Ok(shares)
115 }
116 }
117
118 async fn transition_outcome_ready(
119 dbtx: &mut ClientSMDatabaseTransaction<'_, '_>,
120 signature_shares: Result<BTreeMap<PeerId, Vec<BlindedSignatureShare>>, String>,
121 old_state: MintOutputStateMachine,
122 tbs_pks: BTreeMap<Denomination, AggregatePublicKey>,
123 ) -> MintOutputStateMachine {
124 let Ok(signature_shares) = signature_shares else {
125 return MintOutputStateMachine {
126 common: old_state.common,
127 state: OutputSMState::Aborted,
128 };
129 };
130
131 for (i, request) in old_state.common.issuance_requests.iter().enumerate() {
132 let agg_blind_signature = aggregate_signature_shares(
133 &signature_shares
134 .iter()
135 .map(|(peer, shares)| (peer.to_usize() as u64, shares[i]))
136 .collect(),
137 );
138
139 let spendable_note = request.finalize(agg_blind_signature);
140
141 if !spendable_note.note().verify(
142 *tbs_pks
143 .get(&request.denomination)
144 .expect("No aggregated pk found for denomination"),
145 ) {
146 return MintOutputStateMachine {
147 common: old_state.common,
148 state: OutputSMState::Failure,
149 };
150 }
151
152 dbtx.module_tx()
153 .insert_new_entry(&SpendableNoteKey(spendable_note), &())
154 .await;
155 }
156
157 MintOutputStateMachine {
158 common: old_state.common,
159 state: OutputSMState::Success,
160 }
161 }
162}
163
164pub fn verify_blind_shares(
165 peer: PeerId,
166 signature_shares: Vec<BlindedSignatureShare>,
167 issuance_requests: &[NoteIssuanceRequest],
168 tbs_pks: &BTreeMap<Denomination, BTreeMap<PeerId, PublicKeyShare>>,
169) -> anyhow::Result<Vec<BlindedSignatureShare>> {
170 ensure!(
171 signature_shares.len() == issuance_requests.len(),
172 "Invalid number of signatures shares"
173 );
174
175 for (request, share) in issuance_requests.iter().zip(signature_shares.iter()) {
176 let amount_key = tbs_pks
177 .get(&request.denomination)
178 .expect("No pk shares found for denomination")
179 .get(&peer)
180 .expect("No pk share found for peer");
181
182 ensure!(
183 tbs::verify_signature_share(request.blinded_message(), *share, *amount_key),
184 "Invalid blind signature"
185 );
186 }
187
188 Ok(signature_shares)
189}