// fedimint_client_module/transaction/builder.rs

1use std::fmt;
2use std::ops::RangeInclusive;
3use std::sync::Arc;
4
5use bitcoin::key::Keypair;
6use bitcoin::secp256k1;
7use fedimint_core::Amount;
8use fedimint_core::core::{
9    DynInput, DynOutput, IInput, IOutput, IntoDynInstance, ModuleInstanceId,
10};
11use fedimint_core::encoding::{Decodable, Encodable};
12use fedimint_core::task::{MaybeSend, MaybeSync};
13use fedimint_core::transaction::{Transaction, TransactionSignature};
14use fedimint_logging::LOG_CLIENT;
15use itertools::multiunzip;
16use rand::{CryptoRng, Rng, RngCore};
17use secp256k1::Secp256k1;
18use tracing::warn;
19
20use crate::module::{IdxRange, OutPointRange, StateGenerator};
21use crate::sm::{self, DynState};
22use crate::{
23    InstancelessDynClientInput, InstancelessDynClientInputBundle, InstancelessDynClientInputSM,
24    InstancelessDynClientOutput, InstancelessDynClientOutputBundle, InstancelessDynClientOutputSM,
25    states_add_instance, states_to_instanceless_dyn,
26};
27
/// A single transaction input together with the keys that sign for it.
///
/// When the transaction is built by `TransactionBuilder::build`, every keypair
/// in `keys` produces a Schnorr signature over the transaction id.
#[derive(Clone, Debug)]
pub struct ClientInput<I = DynInput> {
    /// The module input itself (a type-erased [`DynInput`] by default).
    pub input: I,
    /// Keypairs used to sign the transaction id.
    pub keys: Vec<Keypair>,
    /// NOTE(review): presumably the value this input contributes; it is not
    /// read by the builder in this file — confirm with callers.
    pub amount: Amount,
}
34
35#[derive(Clone)]
36pub struct ClientInputSM<S = DynState> {
37    pub state_machines: StateGenerator<S>,
38}
39
40impl<S> fmt::Debug for ClientInputSM<S> {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_str("ClientInputSM")
43    }
44}
45
/// A fake [`sm::Context`] for [`NeverClientStateMachine`].
///
/// This is an enum without variants, so no value of this type can ever be
/// constructed.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
pub enum NeverClientContext {}

// The placeholder context is not associated with any module kind.
impl sm::Context for NeverClientContext {
    const KIND: Option<fedimint_core::core::ModuleKind> = None;
}
53
54/// A fake [`sm::State`] that can actually never happen.
55///
56/// Useful as a default for type inference in cases where there are no
57/// state machines involved in [`ClientInputBundle`].
58#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
59pub enum NeverClientStateMachine {}
60
61impl IntoDynInstance for NeverClientStateMachine {
62    type DynType = DynState;
63
64    fn into_dyn(self, _instance_id: ModuleInstanceId) -> Self::DynType {
65        unreachable!()
66    }
67}
68impl sm::State for NeverClientStateMachine {
69    type ModuleContext = NeverClientContext;
70
71    fn transitions(
72        &self,
73        _context: &Self::ModuleContext,
74        _global_context: &crate::DynGlobalClientContext,
75    ) -> Vec<sm::StateTransition<Self>> {
76        unreachable!()
77    }
78
79    fn operation_id(&self) -> fedimint_core::core::OperationId {
80        unreachable!()
81    }
82}
83
/// A group of inputs and state machines responsible for driving their state.
///
/// These must be kept together as a whole when including in a transaction.
#[derive(Clone, Debug)]
pub struct ClientInputBundle<I = DynInput, S = DynState> {
    /// The inputs of the bundle. `TransactionBuilder::build` places them
    /// contiguously, in this order, in the transaction.
    pub(crate) inputs: Vec<ClientInput<I>>,
    /// State-machine generators. Each is called with the range of this
    /// bundle's inputs once the transaction is built, and never called when
    /// `inputs` is empty.
    pub(crate) sm_gens: Vec<ClientInputSM<S>>,
}
92
93impl<I> ClientInputBundle<I, NeverClientStateMachine> {
94    /// A version of [`Self::new`] for times where input does not require any
95    /// state machines
96    ///
97    /// This avoids type inference issues of `S`, and saves some typing.
98    pub fn new_no_sm(inputs: Vec<ClientInput<I>>) -> Self {
99        if inputs.is_empty() {
100            // TODO: Make it return Result or assert?
101            warn!(target: LOG_CLIENT, "Empty input bundle will be illegal in the future");
102        }
103        Self {
104            inputs,
105            sm_gens: vec![],
106        }
107    }
108}
109
110impl<I, S> ClientInputBundle<I, S>
111where
112    I: IInput + MaybeSend + MaybeSync + 'static,
113    S: sm::IState + MaybeSend + MaybeSync + 'static,
114{
115    pub fn new(inputs: Vec<ClientInput<I>>, sm_gens: Vec<ClientInputSM<S>>) -> Self {
116        Self { inputs, sm_gens }
117    }
118
119    pub fn inputs(&self) -> &[ClientInput<I>] {
120        &self.inputs
121    }
122
123    pub fn sms(&self) -> &[ClientInputSM<S>] {
124        &self.sm_gens
125    }
126
127    pub fn into_instanceless(self) -> InstancelessDynClientInputBundle {
128        InstancelessDynClientInputBundle {
129            inputs: self
130                .inputs
131                .into_iter()
132                .map(|input| InstancelessDynClientInput {
133                    input: Box::new(input.input),
134                    keys: input.keys,
135                    amount: input.amount,
136                })
137                .collect(),
138            sm_gens: self
139                .sm_gens
140                .into_iter()
141                .map(|input_sm| InstancelessDynClientInputSM {
142                    state_machines: states_to_instanceless_dyn(input_sm.state_machines),
143                })
144                .collect(),
145        }
146    }
147}
148
149impl<I, S> ClientInputBundle<I, S> {
150    pub fn is_empty(&self) -> bool {
151        // Notably, sm_gen will not be called when inputs are empty anyway
152        self.inputs.is_empty()
153    }
154}
155
156impl<I> IntoDynInstance for ClientInput<I>
157where
158    I: IntoDynInstance<DynType = DynInput> + 'static,
159{
160    type DynType = ClientInput;
161
162    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInput {
163        ClientInput {
164            input: self.input.into_dyn(module_instance_id),
165            keys: self.keys,
166            amount: self.amount,
167        }
168    }
169}
170
171impl<S> IntoDynInstance for ClientInputSM<S>
172where
173    S: IntoDynInstance<DynType = DynState> + 'static,
174{
175    type DynType = ClientInputSM;
176
177    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputSM {
178        ClientInputSM {
179            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
180        }
181    }
182}
183
184impl<I, S> IntoDynInstance for ClientInputBundle<I, S>
185where
186    I: IntoDynInstance<DynType = DynInput> + 'static,
187    S: IntoDynInstance<DynType = DynState> + 'static,
188{
189    type DynType = ClientInputBundle;
190
191    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
192        ClientInputBundle {
193            inputs: self
194                .inputs
195                .into_iter()
196                .map(|input| input.into_dyn(module_instance_id))
197                .collect::<Vec<ClientInput>>(),
198
199            sm_gens: self
200                .sm_gens
201                .into_iter()
202                .map(|input_sm| input_sm.into_dyn(module_instance_id))
203                .collect::<Vec<ClientInputSM>>(),
204        }
205    }
206}
207
208impl IntoDynInstance for InstancelessDynClientInputBundle {
209    type DynType = ClientInputBundle;
210
211    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
212        ClientInputBundle {
213            inputs: self
214                .inputs
215                .into_iter()
216                .map(|input| ClientInput {
217                    input: DynInput::from_parts(module_instance_id, input.input),
218                    keys: input.keys,
219                    amount: input.amount,
220                })
221                .collect::<Vec<ClientInput>>(),
222
223            sm_gens: self
224                .sm_gens
225                .into_iter()
226                .map(|input_sm| ClientInputSM {
227                    state_machines: states_add_instance(
228                        module_instance_id,
229                        input_sm.state_machines,
230                    ),
231                })
232                .collect::<Vec<ClientInputSM>>(),
233        }
234    }
235}
236
/// A group of outputs and the state machines responsible for driving their
/// state. These must be kept together when included in a transaction.
#[derive(Clone, Debug)]
pub struct ClientOutputBundle<O = DynOutput, S = DynState> {
    /// The outputs of the bundle. `TransactionBuilder::build` places them
    /// contiguously, in this order, in the transaction.
    pub(crate) outputs: Vec<ClientOutput<O>>,
    /// State-machine generators. Each is called with the range of this
    /// bundle's outputs once the transaction is built, and never called when
    /// `outputs` is empty.
    pub(crate) sm_gens: Vec<ClientOutputSM<S>>,
}

/// A single transaction output as seen by the client.
#[derive(Clone, Debug)]
pub struct ClientOutput<O = DynOutput> {
    /// The module output itself (a type-erased [`DynOutput`] by default).
    pub output: O,
    /// NOTE(review): presumably the value of this output; it is not read by
    /// the builder in this file — confirm with callers.
    pub amount: Amount,
}
248
249#[derive(Clone)]
250pub struct ClientOutputSM<S = DynState> {
251    pub state_machines: StateGenerator<S>,
252}
253
254impl<S> fmt::Debug for ClientOutputSM<S> {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        f.write_str("ClientOutputSM")
257    }
258}
259impl<O> ClientOutputBundle<O, NeverClientStateMachine> {
260    /// A version of [`Self::new`] for times where output does not require any
261    /// state machines
262    ///
263    /// This avoids type inference issues of `S`, and saves some typing.
264    pub fn new_no_sm(outputs: Vec<ClientOutput<O>>) -> Self {
265        if outputs.is_empty() {
266            // TODO: Make it return Result or assert?
267            warn!(target: LOG_CLIENT, "Empty output bundle will be illegal in the future");
268        }
269        Self {
270            outputs,
271            sm_gens: vec![],
272        }
273    }
274}
275impl<O, S> ClientOutputBundle<O, S> {
276    pub fn outputs(&self) -> &[ClientOutput<O>] {
277        &self.outputs
278    }
279}
280
281impl<O, S> ClientOutputBundle<O, S>
282where
283    O: IOutput + MaybeSend + MaybeSync + 'static,
284    S: sm::IState + MaybeSend + MaybeSync + 'static,
285{
286    pub fn new(outputs: Vec<ClientOutput<O>>, sm_gens: Vec<ClientOutputSM<S>>) -> Self {
287        Self { outputs, sm_gens }
288    }
289
290    pub fn sms(&self) -> &[ClientOutputSM<S>] {
291        &self.sm_gens
292    }
293
294    pub fn with(mut self, other: Self) -> Self {
295        self.outputs.extend(other.outputs);
296        self.sm_gens.extend(other.sm_gens);
297        self
298    }
299
300    pub fn into_instanceless(self) -> InstancelessDynClientOutputBundle {
301        InstancelessDynClientOutputBundle {
302            outputs: self
303                .outputs
304                .into_iter()
305                .map(|output| InstancelessDynClientOutput {
306                    output: Box::new(output.output),
307                    amount: output.amount,
308                })
309                .collect(),
310            sm_gens: self
311                .sm_gens
312                .into_iter()
313                .map(|output_sm| InstancelessDynClientOutputSM {
314                    state_machines: states_to_instanceless_dyn(output_sm.state_machines),
315                })
316                .collect(),
317        }
318    }
319}
320
321impl<O, S> ClientOutputBundle<O, S> {
322    pub fn is_empty(&self) -> bool {
323        // Notably, sm_gen will not be called when outputs are empty anyway
324        self.outputs.is_empty()
325    }
326}
327
328impl<I, S> IntoDynInstance for ClientOutputBundle<I, S>
329where
330    I: IntoDynInstance<DynType = DynOutput> + 'static,
331    S: IntoDynInstance<DynType = DynState> + 'static,
332{
333    type DynType = ClientOutputBundle;
334
335    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
336        ClientOutputBundle {
337            outputs: self
338                .outputs
339                .into_iter()
340                .map(|output| output.into_dyn(module_instance_id))
341                .collect::<Vec<ClientOutput>>(),
342
343            sm_gens: self
344                .sm_gens
345                .into_iter()
346                .map(|output_sm| output_sm.into_dyn(module_instance_id))
347                .collect::<Vec<ClientOutputSM>>(),
348        }
349    }
350}
351
352impl IntoDynInstance for InstancelessDynClientOutputBundle {
353    type DynType = ClientOutputBundle;
354
355    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
356        ClientOutputBundle {
357            outputs: self
358                .outputs
359                .into_iter()
360                .map(|output| ClientOutput {
361                    output: DynOutput::from_parts(module_instance_id, output.output),
362                    amount: output.amount,
363                })
364                .collect::<Vec<ClientOutput>>(),
365
366            sm_gens: self
367                .sm_gens
368                .into_iter()
369                .map(|output_sm| ClientOutputSM {
370                    state_machines: states_add_instance(
371                        module_instance_id,
372                        output_sm.state_machines,
373                    ),
374                })
375                .collect::<Vec<ClientOutputSM>>(),
376        }
377    }
378}
379
380impl<I> IntoDynInstance for ClientOutput<I>
381where
382    I: IntoDynInstance<DynType = DynOutput> + 'static,
383{
384    type DynType = ClientOutput;
385
386    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutput {
387        ClientOutput {
388            output: self.output.into_dyn(module_instance_id),
389            amount: self.amount,
390        }
391    }
392}
393
394impl<S> IntoDynInstance for ClientOutputSM<S>
395where
396    S: IntoDynInstance<DynType = DynState> + 'static,
397{
398    type DynType = ClientOutputSM;
399
400    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputSM {
401        ClientOutputSM {
402            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
403        }
404    }
405}
406
/// Accumulates input and output bundles and turns them into a signed
/// [`Transaction`] plus the state machines generated by each bundle (see
/// `TransactionBuilder::build`).
#[derive(Default, Clone, Debug)]
pub struct TransactionBuilder {
    /// Input bundles, in the order they were added.
    inputs: Vec<ClientInputBundle>,
    /// Output bundles, in the order they were added.
    outputs: Vec<ClientOutputBundle>,
}
412
impl TransactionBuilder {
    /// Create a builder with no input or output bundles.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append an input bundle. Bundles keep the order in which they are added.
    pub fn with_inputs(mut self, inputs: ClientInputBundle) -> Self {
        self.inputs.push(inputs);
        self
    }

    /// Append an output bundle. Bundles keep the order in which they are added.
    pub fn with_outputs(mut self, outputs: ClientOutputBundle) -> Self {
        self.outputs.push(outputs);
        self
    }

    /// Assemble and sign the transaction, and generate the state machines of
    /// every non-empty bundle.
    ///
    /// Inputs and outputs are flattened in bundle order, so each bundle
    /// occupies a contiguous index range. The build proceeds in three steps:
    ///
    /// 1. A random 8-byte nonce is drawn from `rng`.
    /// 2. The txid is computed from the inputs, outputs and nonce.
    /// 3. Every keypair of every input produces a Schnorr signature over the
    ///    txid. The signatures are collected, in input order, into a
    ///    [`TransactionSignature::NaiveMultisig`].
    ///
    /// Each bundle's state-machine generators are then called with an
    /// [`OutPointRange`] made of the txid and the index range of that bundle's
    /// inputs (or outputs). Empty bundles are skipped, so their generators are
    /// never called.
    ///
    /// Returns the transaction together with all generated states: states of
    /// input bundles first, then states of output bundles.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) only if one of these internal invariants breaks:
    /// - the txid does not have the length required for a signing message;
    /// - a non-empty bundle has no matching indices;
    /// - `IdxRange::from_inclusive` rejects a bundle's index range.
    pub fn build<C, R: RngCore + CryptoRng>(
        self,
        secp_ctx: &Secp256k1<C>,
        mut rng: R,
    ) -> (Transaction, Vec<DynState>)
    where
        C: secp256k1::Signing + secp256k1::Verification,
    {
        // `input_idx_to_bundle_idx[input_idx]` stores the index of the bundle
        // that the input at `input_idx` comes from, so we can later call the
        // state machine generators of each input bundle with the right range.
        // It is always monotonically increasing, e.g. `[0, 0, 1, 2, 2, 2, 4]`
        // (here bundle 3 is empty and therefore has no entries).
        let (input_idx_to_bundle_idx, inputs, input_keys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
            self.inputs
                .iter()
                .enumerate()
                .flat_map(|(bundle_idx, bundle)| {
                    bundle
                        .inputs
                        .iter()
                        .map(move |input| (bundle_idx, input.input.clone(), input.keys.clone()))
                }),
        );
        // `output_idx_to_bundle_idx` works exactly like `input_idx_to_bundle_idx`
        // above, but for outputs.
        let (output_idx_to_bundle_idx, outputs): (Vec<_>, Vec<_>) = multiunzip(
            self.outputs
                .iter()
                .enumerate()
                .flat_map(|(bundle_idx, bundle)| {
                    bundle
                        .outputs
                        .iter()
                        .map(move |output| (bundle_idx, output.output.clone()))
                }),
        );
        let nonce: [u8; 8] = rng.r#gen();

        let txid = Transaction::tx_hash_from_parts(&inputs, &outputs, nonce);
        let msg = secp256k1::Message::from_digest_slice(&txid[..]).expect("txid has right length");

        // One signature per keypair: input order first, then key order within
        // each input.
        let signatures = input_keys
            .iter()
            .flatten()
            .map(|keypair| secp_ctx.sign_schnorr(&msg, keypair))
            .collect();

        let transaction = Transaction {
            inputs,
            outputs,
            nonce,
            signatures: TransactionSignature::NaiveMultisig(signatures),
        };

        // Lazy iterator: the generators only run in the final `collect` below.
        let input_states = self
            .inputs
            .into_iter()
            .enumerate()
            .filter(|(_, bundle)| !bundle.is_empty())
            .flat_map(|(bundle_idx, bundle)| {
                let input_idxs = find_range_of_matching_items(&input_idx_to_bundle_idx, bundle_idx)
                    .expect("Non empty bundles must always have a match");
                bundle.sm_gens.into_iter().flat_map(move |sm| {
                    (sm.state_machines)(OutPointRange::new(
                        txid,
                        IdxRange::from_inclusive(input_idxs.clone()).expect("can't overflow"),
                    ))
                })
            });

        // Same as `input_states`, but for output bundles.
        let output_states = self
            .outputs
            .into_iter()
            .enumerate()
            .filter(|(_, bundle)| !bundle.is_empty())
            .flat_map(|(bundle_idx, bundle)| {
                let output_idxs =
                    find_range_of_matching_items(&output_idx_to_bundle_idx, bundle_idx)
                        .expect("Non empty bundles must always have a match");
                bundle.sm_gens.into_iter().flat_map(move |sm| {
                    (sm.state_machines)(OutPointRange::new(
                        txid,
                        IdxRange::from_inclusive(output_idxs.clone())
                            .expect("can't possibly overflow"),
                    ))
                })
            });
        (transaction, input_states.chain(output_states).collect())
    }

    /// Iterate over the inputs of all bundles, in bundle order.
    pub fn inputs(&self) -> impl Iterator<Item = &ClientInput> {
        self.inputs.iter().flat_map(|i| i.inputs.iter())
    }

    /// Iterate over the outputs of all bundles, in bundle order.
    pub fn outputs(&self) -> impl Iterator<Item = &ClientOutput> {
        self.outputs.iter().flat_map(|i| i.outputs.iter())
    }
}
526
/// Find the inclusive range of indexes in the monotonically increasing
/// (sorted, non-decreasing) `arr` whose elements are equal to `item`.
///
/// Returns `None` when `item` does not occur in `arr`, including when `arr` is
/// empty.
///
/// Because `arr` is sorted, both ends of the matching run are found by binary
/// search, so each call is `O(log n)`. `TransactionBuilder::build` calls this
/// once per non-empty bundle; a linear scan would make it
/// `O(bundles × items)`.
fn find_range_of_matching_items(arr: &[usize], item: usize) -> Option<RangeInclusive<u64>> {
    // `arr` must be monotonically increasing
    debug_assert!(arr.windows(2).all(|w| w[0] <= w[1]));

    // `start` is the first index whose element is `>= item`, and `end` is the
    // first index whose element is `> item`. The matching run is `start..end`.
    let start = arr.partition_point(|&x| x < item);
    let end = arr.partition_point(|&x| x <= item);

    if start == end {
        // `item` does not occur in `arr`.
        return None;
    }

    // `usize` -> `u64` is lossless on all supported (<= 64-bit) targets.
    Some(start as u64..=(end - 1) as u64)
}
541
// Sanity checks for `find_range_of_matching_items`, including a value missing
// in the middle of the array, a run at the very end, and single-element input.
#[test]
fn find_range_of_matching_items_sanity() {
    assert_eq!(find_range_of_matching_items(&[0, 0], 0), Some(0..=1));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 0), Some(0..=1));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 1), Some(2..=2));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 2), None);
    assert_eq!(find_range_of_matching_items(&[], 0), None);

    // Single-element array.
    assert_eq!(find_range_of_matching_items(&[7], 7), Some(0..=0));
    assert_eq!(find_range_of_matching_items(&[7], 6), None);

    // Runs around a missing value (`3`), plus a run at the very end.
    let arr: &[usize] = &[0, 0, 1, 2, 2, 2, 4];
    assert_eq!(find_range_of_matching_items(arr, 0), Some(0..=1));
    assert_eq!(find_range_of_matching_items(arr, 1), Some(2..=2));
    assert_eq!(find_range_of_matching_items(arr, 2), Some(3..=5));
    assert_eq!(find_range_of_matching_items(arr, 3), None);
    assert_eq!(find_range_of_matching_items(arr, 4), Some(6..=6));
    assert_eq!(find_range_of_matching_items(arr, 5), None);
}
550
551fn state_gen_to_dyn<S>(
552    state_gen: StateGenerator<S>,
553    module_instance: ModuleInstanceId,
554) -> StateGenerator<DynState>
555where
556    S: IntoDynInstance<DynType = DynState> + 'static,
557{
558    Arc::new(move |out_point_range| {
559        let states = state_gen(out_point_range);
560        states
561            .into_iter()
562            .map(|state| state.into_dyn(module_instance))
563            .collect()
564    })
565}
566
// Further unit tests for this module live in a separate `tests` submodule
// file, which is compiled only for test builds.
#[cfg(test)]
mod tests;