Skip to main content

fedimint_client_module/sm/
state.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::future::Future;
4use std::hash;
5use std::io::{Error, Read, Write};
6use std::pin::Pin;
7use std::sync::Arc;
8
9use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, ModuleKind, OperationId};
10use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
11use fedimint_core::module::registry::ModuleDecoderRegistry;
12use fedimint_core::task::{MaybeSend, MaybeSync};
13use fedimint_core::util::BoxFuture;
14use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
15
16use crate::DynGlobalClientContext;
17use crate::sm::ClientSMDatabaseTransaction;
18
19/// Implementors act as state machines that can be executed
20pub trait State:
21    Debug
22    + Clone
23    + Eq
24    + PartialEq
25    + std::hash::Hash
26    + Encodable
27    + Decodable
28    + MaybeSend
29    + MaybeSync
30    + 'static
31{
32    /// Additional resources made available in this module's state transitions
33    type ModuleContext: Context;
34
35    /// All possible transitions from the current state to other states. See
36    /// [`StateTransition`] for details.
37    fn transitions(
38        &self,
39        context: &Self::ModuleContext,
40        global_context: &DynGlobalClientContext,
41    ) -> Vec<StateTransition<Self>>;
42
43    // TODO: move out of this interface into wrapper struct (see OperationState)
44    /// Operation this state machine belongs to. See [`OperationId`] for
45    /// details.
46    fn operation_id(&self) -> OperationId;
47
48    /// Human-readable visualization of this state machine for debugging.
49    ///
50    /// Defaults to `Debug` output. Override to show only the relevant
51    /// fields instead of dumping raw crypto bytes.
52    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
53        write!(f, "{indent}{self:?}")
54    }
55}
56
57/// Object-safe version of [`State`]
58pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
59    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
60
61    /// All possible transitions from the state
62    fn transitions(
63        &self,
64        context: &DynContext,
65        global_context: &DynGlobalClientContext,
66    ) -> Vec<StateTransition<DynState>>;
67
68    /// Operation this state machine belongs to. See [`OperationId`] for
69    /// details.
70    fn operation_id(&self) -> OperationId;
71
72    /// Clone state
73    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;
74
75    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;
76
77    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);
78
79    /// Human-readable visualization for debugging. See
80    /// [`State::fmt_visualization`].
81    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result;
82}
83
/// Something that can be a [`DynContext`] for a state machine
///
/// General purpose code should use [`DynContext`] instead
pub trait IContext: Debug {
    /// Exposes the context as [`Any`] so it can be downcast to its concrete
    /// type (the blanket [`IState`] implementation does this before calling
    /// [`State::transitions`]).
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
    /// Module kind associated with this context, if any. For [`Context`]
    /// implementors this is the value of [`Context::KIND`].
    fn module_kind(&self) -> Option<ModuleKind>;
}
91
// Declares the `DynContext` newtype wrapping `Arc<IContext>`.
// NOTE(review): the items generated for it (constructors, `Deref`, ...) come
// from the `module_plugin_dyn_newtype_define!` macro in `fedimint_core` and are
// not visible here; check the macro definition before relying on them.
module_plugin_dyn_newtype_define! {
    /// A shared context for a module client state machine
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
97
98/// Additional data made available to state machines of a module (e.g. API
99/// clients)
100pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {
101    const KIND: Option<ModuleKind>;
102}
103
104/// Type-erased version of [`Context`]
105impl<T> IContext for T
106where
107    T: Context + 'static + MaybeSend + MaybeSync,
108{
109    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
110        self
111    }
112
113    fn module_kind(&self) -> Option<ModuleKind> {
114        T::KIND
115    }
116}
117
/// Boxed, pinned trigger future of a [`StateTransition`]. Its output is a
/// `serde_json::Value` that is later passed to the transition function.
type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;
119
// TODO: remove Arc, maybe make it a fn pointer?
/// Shared, type-erased transition function of a [`StateTransition`].
///
/// It receives the database transaction, the JSON value produced by the
/// trigger future and the current state `S` (by value), and returns a boxed
/// future resolving to the next state.
pub type StateTransitionFunction<S> = Arc<
    maybe_add_send_sync!(
        dyn for<'a> Fn(
            &'a mut ClientSMDatabaseTransaction<'_, '_>,
            serde_json::Value,
            S,
        ) -> BoxFuture<'a, S>
    ),
>;
130
/// Represents one or multiple possible state transitions triggered in a common
/// way
pub struct StateTransition<S> {
    /// Future that will block until a state transition is possible.
    ///
    /// **The trigger future must be idempotent since it might be re-run if the
    /// client is restarted.**
    ///
    /// To wait for a possible state transition it can query external APIs,
    /// subscribe to events emitted by other state machines, etc.
    /// Optionally, it can also return some data that will be given to the
    /// state transition function, see the `transition` docs for details.
    pub trigger: TriggerFuture,
    /// State transition function that, using the output of the `trigger`,
    /// performs the appropriate state transition.
    ///
    /// **This function shall not block on network IO or similar things as all
    /// actual state transitions are run serially.**
    ///
    /// Since this function can return different output states depending on
    /// the `Value` returned by the `trigger` future, it can be used to model
    /// multiple possible state transitions at once. E.g. instead of having
    /// two state transitions querying the same API endpoint and each waiting
    /// for a specific value to be returned to trigger their respective state
    /// transition, we can have one `trigger` future querying the API and,
    /// depending on the return value, run different state transitions,
    /// saving network requests.
    pub transition: StateTransitionFunction<S>,
}
160
161impl<S> StateTransition<S> {
162    /// Creates a new `StateTransition` where the `trigger` future returns a
163    /// value of type `V` that is then given to the `transition` function.
164    pub fn new<V, Trigger, TransitionFn>(
165        trigger: Trigger,
166        transition: TransitionFn,
167    ) -> StateTransition<S>
168    where
169        S: MaybeSend + MaybeSync + Clone + 'static,
170        V: serde::Serialize + serde::de::DeserializeOwned + Send,
171        Trigger: Future<Output = V> + MaybeSend + 'static,
172        TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
173            + MaybeSend
174            + MaybeSync
175            + Clone
176            + 'static,
177    {
178        StateTransition {
179            trigger: Box::pin(async {
180                let val = trigger.await;
181                serde_json::to_value(val).expect("Value could not be serialized")
182            }),
183            transition: Arc::new(move |dbtx, val, state| {
184                let transition = transition.clone();
185                Box::pin(async move {
186                    let typed_val: V = serde_json::from_value(val)
187                        .expect("Deserialize trigger return value failed");
188                    transition(dbtx, typed_val, state.clone()).await
189                })
190            }),
191        }
192    }
193}
194
195impl<T> IState for T
196where
197    T: State,
198{
199    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
200        self
201    }
202
203    fn transitions(
204        &self,
205        context: &DynContext,
206        global_context: &DynGlobalClientContext,
207    ) -> Vec<StateTransition<DynState>> {
208        <T as State>::transitions(
209            self,
210            context.as_any().downcast_ref().expect("Wrong module"),
211            global_context,
212        )
213        .into_iter()
214        .map(|st| StateTransition {
215            trigger: st.trigger,
216            transition: Arc::new(
217                move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
218                    let transition = st.transition.clone();
219                    Box::pin(async move {
220                        let new_state = transition(
221                            dbtx,
222                            val,
223                            state
224                                .as_any()
225                                .downcast_ref::<T>()
226                                .expect("Wrong module")
227                                .clone(),
228                        )
229                        .await;
230                        DynState::from_typed(state.module_instance_id(), new_state)
231                    })
232                },
233            ),
234        })
235        .collect()
236    }
237
238    fn operation_id(&self) -> OperationId {
239        <T as State>::operation_id(self)
240    }
241
242    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
243        DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
244    }
245
246    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
247        let other: &T = other
248            .as_any()
249            .downcast_ref()
250            .expect("Type is ensured in previous step");
251
252        self == other
253    }
254
255    fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
256        self.hash(&mut hasher);
257    }
258
259    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
260        <T as State>::fmt_visualization(self, f, indent)
261    }
262}
263
/// A type-erased state of a state machine belonging to a module instance, see
/// [`State`]
///
/// Field `.0` is the boxed state, field `.1` the id of the module instance it
/// belongs to.
pub struct DynState(
    Box<maybe_add_send_sync!(dyn IState + 'static)>,
    ModuleInstanceId,
);
270
271impl IState for DynState {
272    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
273        (**self).as_any()
274    }
275
276    fn transitions(
277        &self,
278        context: &DynContext,
279        global_context: &DynGlobalClientContext,
280    ) -> Vec<StateTransition<DynState>> {
281        (**self).transitions(context, global_context)
282    }
283
284    fn operation_id(&self) -> OperationId {
285        (**self).operation_id()
286    }
287
288    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
289        (**self).clone(module_instance_id)
290    }
291
292    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
293        (**self).erased_eq_no_instance_id(other)
294    }
295
296    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher) {
297        (**self).erased_hash_no_instance_id(hasher);
298    }
299
300    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
301        (**self).fmt_visualization(f, indent)
302    }
303}
304
305impl IntoDynInstance for DynState {
306    type DynType = DynState;
307
308    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
309        assert_eq!(instance_id, self.1);
310        self
311    }
312}
313
314impl std::ops::Deref for DynState {
315    type Target = maybe_add_send_sync!(dyn IState + 'static);
316
317    fn deref(&self) -> &<Self as std::ops::Deref>::Target {
318        &*self.0
319    }
320}
321
322impl hash::Hash for DynState {
323    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
324        self.1.hash(hasher);
325        self.0.erased_hash_no_instance_id(hasher);
326    }
327}
328
329impl DynState {
330    pub fn module_instance_id(&self) -> ModuleInstanceId {
331        self.1
332    }
333
334    pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
335    where
336        I: IState + 'static,
337    {
338        Self(Box::new(typed), module_instance_id)
339    }
340
341    pub fn from_parts(
342        module_instance_id: ::fedimint_core::core::ModuleInstanceId,
343        dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
344    ) -> Self {
345        Self(dynbox, module_instance_id)
346    }
347}
348
349impl std::fmt::Debug for DynState {
350    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351        std::fmt::Debug::fmt(&self.0, f)
352    }
353}
354
355impl std::ops::DerefMut for DynState {
356    fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
357        &mut *self.0
358    }
359}
360
361impl Clone for DynState {
362    fn clone(&self) -> Self {
363        self.0.clone(self.1)
364    }
365}
366
367impl PartialEq for DynState {
368    fn eq(&self, other: &Self) -> bool {
369        if self.1 != other.1 {
370            return false;
371        }
372        self.erased_eq_no_instance_id(other)
373    }
374}
375impl Eq for DynState {}
376
377impl Encodable for DynState {
378    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
379        self.1.consensus_encode(writer)?;
380        self.0.consensus_encode_dyn(writer)
381    }
382}
383impl Decodable for DynState {
384    fn consensus_decode_partial<R: std::io::Read>(
385        reader: &mut R,
386        decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
387    ) -> Result<Self, fedimint_core::encoding::DecodeError> {
388        let module_id =
389            fedimint_core::core::ModuleInstanceId::consensus_decode_partial(reader, decoders)?;
390        decoders
391            .get_expect(module_id)
392            .decode_partial(reader, module_id, decoders)
393    }
394}
395
396impl DynState {
397    /// Return a human-readable visualization string for debugging.
398    pub fn visualization(&self, indent: &str) -> String {
399        let mut s = String::new();
400        self.fmt_visualization(&mut s, indent)
401            .expect("fmt::Write to String never fails");
402        s
403    }
404
405    /// `true` if this state allows no further transitions
406    pub fn is_terminal(
407        &self,
408        context: &DynContext,
409        global_context: &DynGlobalClientContext,
410    ) -> bool {
411        self.transitions(context, global_context).is_empty()
412    }
413}
414
/// Pairs a state `S` with the id of the operation it belongs to.
#[derive(Debug)]
pub struct OperationState<S> {
    /// Operation the wrapped state belongs to
    pub operation_id: OperationId,
    /// The wrapped state
    pub state: S,
}
420
421/// Wrapper for states that don't want to carry around their operation id. `S`
422/// is allowed to panic when `operation_id` is called.
423impl<S> State for OperationState<S>
424where
425    S: State,
426{
427    type ModuleContext = S::ModuleContext;
428
429    fn transitions(
430        &self,
431        context: &Self::ModuleContext,
432        global_context: &DynGlobalClientContext,
433    ) -> Vec<StateTransition<Self>> {
434        let transitions: Vec<StateTransition<OperationState<S>>> = self
435            .state
436            .transitions(context, global_context)
437            .into_iter()
438            .map(
439                |StateTransition {
440                     trigger,
441                     transition,
442                 }| {
443                    let op_transition: StateTransitionFunction<Self> =
444                        Arc::new(move |dbtx, value, op_state| {
445                            let transition = transition.clone();
446                            Box::pin(async move {
447                                let state = transition(dbtx, value, op_state.state).await;
448                                OperationState {
449                                    operation_id: op_state.operation_id,
450                                    state,
451                                }
452                            })
453                        });
454
455                    StateTransition {
456                        trigger,
457                        transition: op_transition,
458                    }
459                },
460            )
461            .collect();
462        transitions
463    }
464
465    fn operation_id(&self) -> OperationId {
466        self.operation_id
467    }
468
469    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
470        self.state.fmt_visualization(f, indent)
471    }
472}
473
474// TODO: can we get rid of `GC`? Maybe make it an associated type of `State`
475// instead?
476impl<S> IntoDynInstance for OperationState<S>
477where
478    S: State,
479{
480    type DynType = DynState;
481
482    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
483        DynState::from_typed(instance_id, self)
484    }
485}
486
487impl<S> Encodable for OperationState<S>
488where
489    S: State,
490{
491    fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<(), Error> {
492        self.operation_id.consensus_encode(writer)?;
493        self.state.consensus_encode(writer)?;
494        Ok(())
495    }
496}
497
498impl<S> Decodable for OperationState<S>
499where
500    S: State,
501{
502    fn consensus_decode_partial<R: Read>(
503        read: &mut R,
504        modules: &ModuleDecoderRegistry,
505    ) -> Result<Self, DecodeError> {
506        let operation_id = OperationId::consensus_decode_partial(read, modules)?;
507        let state = S::consensus_decode_partial(read, modules)?;
508
509        Ok(OperationState {
510            operation_id,
511            state,
512        })
513    }
514}
515
516// TODO: derive after getting rid of `GC` type arg
517impl<S> PartialEq for OperationState<S>
518where
519    S: State,
520{
521    fn eq(&self, other: &Self) -> bool {
522        self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state)
523    }
524}
525
526impl<S> Eq for OperationState<S> where S: State {}
527
528impl<S> hash::Hash for OperationState<S>
529where
530    S: hash::Hash,
531{
532    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
533        self.operation_id.hash(hasher);
534        self.state.hash(hasher);
535    }
536}
537
538impl<S> Clone for OperationState<S>
539where
540    S: State,
541{
542    fn clone(&self) -> Self {
543        OperationState {
544            operation_id: self.operation_id,
545            state: self.state.clone(),
546        }
547    }
548}