fedimint_client_module/sm/
state.rs1use std::any::Any;
2use std::fmt::Debug;
3use std::future::Future;
4use std::hash;
5use std::io::{Error, Read, Write};
6use std::pin::Pin;
7use std::sync::Arc;
8
9use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, ModuleKind, OperationId};
10use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
11use fedimint_core::module::registry::ModuleDecoderRegistry;
12use fedimint_core::task::{MaybeSend, MaybeSync};
13use fedimint_core::util::BoxFuture;
14use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
15
16use crate::DynGlobalClientContext;
17use crate::sm::ClientSMDatabaseTransaction;
18
19pub trait State:
21 Debug
22 + Clone
23 + Eq
24 + PartialEq
25 + std::hash::Hash
26 + Encodable
27 + Decodable
28 + MaybeSend
29 + MaybeSync
30 + 'static
31{
32 type ModuleContext: Context;
34
35 fn transitions(
38 &self,
39 context: &Self::ModuleContext,
40 global_context: &DynGlobalClientContext,
41 ) -> Vec<StateTransition<Self>>;
42
43 fn operation_id(&self) -> OperationId;
47
48 fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
53 write!(f, "{indent}{self:?}")
54 }
55}
56
/// Object-safe, type-erased counterpart of [`State`].
///
/// Every `T: State` gets this trait through a blanket implementation, and
/// [`DynState`] stores its payload as a boxed `dyn IState`.
pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
    /// Returns `self` as a `dyn Any` so that callers can downcast to the
    /// concrete state type.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));

    /// Type-erased version of [`State::transitions`]: takes a [`DynContext`]
    /// and yields transitions that operate on [`DynState`].
    fn transitions(
        &self,
        context: &DynContext,
        global_context: &DynGlobalClientContext,
    ) -> Vec<StateTransition<DynState>>;

    /// Returns the id of the operation this state is associated with.
    fn operation_id(&self) -> OperationId;

    /// Clones the state into a new [`DynState`] tagged with
    /// `module_instance_id`.
    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;

    /// Compares this state with the state wrapped by `other`, without
    /// looking at module instance ids.
    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;

    /// Feeds this state (without a module instance id) into `hasher`.
    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);

    /// Writes a textual rendering of this state into `f`, prefixed by
    /// `indent`.
    fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result;
}
83
/// Object-safe, type-erased view of a module [`Context`].
pub trait IContext: Debug {
    /// Returns `self` as a `dyn Any` so that callers can downcast to the
    /// concrete context type.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
    /// Returns the module kind of this context, if it has one.
    fn module_kind(&self) -> Option<ModuleKind>;
}
91
// Declares `DynContext`, the type-erased context passed to
// `IState::transitions`.
// NOTE(review): presumably this expands to a cloneable newtype around
// `Arc<dyn IContext>` that derefs to the trait object -- confirm against the
// macro definition in `fedimint_core`.
module_plugin_dyn_newtype_define! {
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
97
/// Module-specific context made available to a [`State`] when its
/// transitions are computed.
pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {
    /// Module kind reported through [`IContext::module_kind`]; `None` if the
    /// context has no module kind.
    const KIND: Option<ModuleKind>;
}
103
104impl<T> IContext for T
106where
107 T: Context + 'static + MaybeSend + MaybeSync,
108{
109 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
110 self
111 }
112
113 fn module_kind(&self) -> Option<ModuleKind> {
114 T::KIND
115 }
116}
117
/// Boxed, pinned future used as the trigger of a [`StateTransition`]; it
/// resolves to a JSON value.
type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;
119
/// Shared, type-erased transition function of a [`StateTransition`].
///
/// It receives a state machine database transaction, a JSON value and the
/// current state `S`, and returns a future resolving to the next state.
pub type StateTransitionFunction<S> = Arc<
    maybe_add_send_sync!(
        dyn for<'a> Fn(
            &'a mut ClientSMDatabaseTransaction<'_, '_>,
            serde_json::Value,
            S,
        ) -> BoxFuture<'a, S>
    ),
>;
130
/// A possible transition out of a state of type `S`.
pub struct StateTransition<S> {
    /// Future that resolves to a JSON value.
    // NOTE(review): presumably the executor awaits this future and feeds its
    // output to `transition` -- the driving code is not part of this file.
    pub trigger: TriggerFuture,
    /// Function that computes the next state from a database transaction, a
    /// JSON value and the current state.
    pub transition: StateTransitionFunction<S>,
}
160
161impl<S> StateTransition<S> {
162 pub fn new<V, Trigger, TransitionFn>(
165 trigger: Trigger,
166 transition: TransitionFn,
167 ) -> StateTransition<S>
168 where
169 S: MaybeSend + MaybeSync + Clone + 'static,
170 V: serde::Serialize + serde::de::DeserializeOwned + Send,
171 Trigger: Future<Output = V> + MaybeSend + 'static,
172 TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
173 + MaybeSend
174 + MaybeSync
175 + Clone
176 + 'static,
177 {
178 StateTransition {
179 trigger: Box::pin(async {
180 let val = trigger.await;
181 serde_json::to_value(val).expect("Value could not be serialized")
182 }),
183 transition: Arc::new(move |dbtx, val, state| {
184 let transition = transition.clone();
185 Box::pin(async move {
186 let typed_val: V = serde_json::from_value(val)
187 .expect("Deserialize trigger return value failed");
188 transition(dbtx, typed_val, state.clone()).await
189 })
190 }),
191 }
192 }
193}
194
195impl<T> IState for T
196where
197 T: State,
198{
199 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
200 self
201 }
202
203 fn transitions(
204 &self,
205 context: &DynContext,
206 global_context: &DynGlobalClientContext,
207 ) -> Vec<StateTransition<DynState>> {
208 <T as State>::transitions(
209 self,
210 context.as_any().downcast_ref().expect("Wrong module"),
211 global_context,
212 )
213 .into_iter()
214 .map(|st| StateTransition {
215 trigger: st.trigger,
216 transition: Arc::new(
217 move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
218 let transition = st.transition.clone();
219 Box::pin(async move {
220 let new_state = transition(
221 dbtx,
222 val,
223 state
224 .as_any()
225 .downcast_ref::<T>()
226 .expect("Wrong module")
227 .clone(),
228 )
229 .await;
230 DynState::from_typed(state.module_instance_id(), new_state)
231 })
232 },
233 ),
234 })
235 .collect()
236 }
237
238 fn operation_id(&self) -> OperationId {
239 <T as State>::operation_id(self)
240 }
241
242 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
243 DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
244 }
245
246 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
247 let other: &T = other
248 .as_any()
249 .downcast_ref()
250 .expect("Type is ensured in previous step");
251
252 self == other
253 }
254
255 fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
256 self.hash(&mut hasher);
257 }
258
259 fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
260 <T as State>::fmt_visualization(self, f, indent)
261 }
262}
263
/// A type-erased state: a boxed [`IState`] trait object (field `0`) paired
/// with the id of the module instance it belongs to (field `1`).
pub struct DynState(
    Box<maybe_add_send_sync!(dyn IState + 'static)>,
    ModuleInstanceId,
);
270
271impl IState for DynState {
272 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
273 (**self).as_any()
274 }
275
276 fn transitions(
277 &self,
278 context: &DynContext,
279 global_context: &DynGlobalClientContext,
280 ) -> Vec<StateTransition<DynState>> {
281 (**self).transitions(context, global_context)
282 }
283
284 fn operation_id(&self) -> OperationId {
285 (**self).operation_id()
286 }
287
288 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
289 (**self).clone(module_instance_id)
290 }
291
292 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
293 (**self).erased_eq_no_instance_id(other)
294 }
295
296 fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher) {
297 (**self).erased_hash_no_instance_id(hasher);
298 }
299
300 fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
301 (**self).fmt_visualization(f, indent)
302 }
303}
304
305impl IntoDynInstance for DynState {
306 type DynType = DynState;
307
308 fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
309 assert_eq!(instance_id, self.1);
310 self
311 }
312}
313
314impl std::ops::Deref for DynState {
315 type Target = maybe_add_send_sync!(dyn IState + 'static);
316
317 fn deref(&self) -> &<Self as std::ops::Deref>::Target {
318 &*self.0
319 }
320}
321
322impl hash::Hash for DynState {
323 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
324 self.1.hash(hasher);
325 self.0.erased_hash_no_instance_id(hasher);
326 }
327}
328
329impl DynState {
330 pub fn module_instance_id(&self) -> ModuleInstanceId {
331 self.1
332 }
333
334 pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
335 where
336 I: IState + 'static,
337 {
338 Self(Box::new(typed), module_instance_id)
339 }
340
341 pub fn from_parts(
342 module_instance_id: ::fedimint_core::core::ModuleInstanceId,
343 dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
344 ) -> Self {
345 Self(dynbox, module_instance_id)
346 }
347}
348
349impl std::fmt::Debug for DynState {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 std::fmt::Debug::fmt(&self.0, f)
352 }
353}
354
355impl std::ops::DerefMut for DynState {
356 fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
357 &mut *self.0
358 }
359}
360
361impl Clone for DynState {
362 fn clone(&self) -> Self {
363 self.0.clone(self.1)
364 }
365}
366
367impl PartialEq for DynState {
368 fn eq(&self, other: &Self) -> bool {
369 if self.1 != other.1 {
370 return false;
371 }
372 self.erased_eq_no_instance_id(other)
373 }
374}
// Marker impl: `DynState` equality is declared to be a full equivalence
// relation.
impl Eq for DynState {}
376
377impl Encodable for DynState {
378 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
379 self.1.consensus_encode(writer)?;
380 self.0.consensus_encode_dyn(writer)
381 }
382}
383impl Decodable for DynState {
384 fn consensus_decode_partial<R: std::io::Read>(
385 reader: &mut R,
386 decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
387 ) -> Result<Self, fedimint_core::encoding::DecodeError> {
388 let module_id =
389 fedimint_core::core::ModuleInstanceId::consensus_decode_partial(reader, decoders)?;
390 decoders
391 .get_expect(module_id)
392 .decode_partial(reader, module_id, decoders)
393 }
394}
395
396impl DynState {
397 pub fn visualization(&self, indent: &str) -> String {
399 let mut s = String::new();
400 self.fmt_visualization(&mut s, indent)
401 .expect("fmt::Write to String never fails");
402 s
403 }
404
405 pub fn is_terminal(
407 &self,
408 context: &DynContext,
409 global_context: &DynGlobalClientContext,
410 ) -> bool {
411 self.transitions(context, global_context).is_empty()
412 }
413}
414
/// A state `S` paired with an explicitly stored operation id.
#[derive(Debug)]
pub struct OperationState<S> {
    /// Id returned by this wrapper's `operation_id()`.
    pub operation_id: OperationId,
    /// The wrapped state.
    pub state: S,
}
420
421impl<S> State for OperationState<S>
424where
425 S: State,
426{
427 type ModuleContext = S::ModuleContext;
428
429 fn transitions(
430 &self,
431 context: &Self::ModuleContext,
432 global_context: &DynGlobalClientContext,
433 ) -> Vec<StateTransition<Self>> {
434 let transitions: Vec<StateTransition<OperationState<S>>> = self
435 .state
436 .transitions(context, global_context)
437 .into_iter()
438 .map(
439 |StateTransition {
440 trigger,
441 transition,
442 }| {
443 let op_transition: StateTransitionFunction<Self> =
444 Arc::new(move |dbtx, value, op_state| {
445 let transition = transition.clone();
446 Box::pin(async move {
447 let state = transition(dbtx, value, op_state.state).await;
448 OperationState {
449 operation_id: op_state.operation_id,
450 state,
451 }
452 })
453 });
454
455 StateTransition {
456 trigger,
457 transition: op_transition,
458 }
459 },
460 )
461 .collect();
462 transitions
463 }
464
465 fn operation_id(&self) -> OperationId {
466 self.operation_id
467 }
468
469 fn fmt_visualization(&self, f: &mut dyn std::fmt::Write, indent: &str) -> std::fmt::Result {
470 self.state.fmt_visualization(f, indent)
471 }
472}
473
474impl<S> IntoDynInstance for OperationState<S>
477where
478 S: State,
479{
480 type DynType = DynState;
481
482 fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
483 DynState::from_typed(instance_id, self)
484 }
485}
486
487impl<S> Encodable for OperationState<S>
488where
489 S: State,
490{
491 fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<(), Error> {
492 self.operation_id.consensus_encode(writer)?;
493 self.state.consensus_encode(writer)?;
494 Ok(())
495 }
496}
497
498impl<S> Decodable for OperationState<S>
499where
500 S: State,
501{
502 fn consensus_decode_partial<R: Read>(
503 read: &mut R,
504 modules: &ModuleDecoderRegistry,
505 ) -> Result<Self, DecodeError> {
506 let operation_id = OperationId::consensus_decode_partial(read, modules)?;
507 let state = S::consensus_decode_partial(read, modules)?;
508
509 Ok(OperationState {
510 operation_id,
511 state,
512 })
513 }
514}
515
516impl<S> PartialEq for OperationState<S>
518where
519 S: State,
520{
521 fn eq(&self, other: &Self) -> bool {
522 self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state)
523 }
524}
525
// Marker impl: `OperationState` equality is declared to be a full
// equivalence relation.
impl<S> Eq for OperationState<S> where S: State {}
527
528impl<S> hash::Hash for OperationState<S>
529where
530 S: hash::Hash,
531{
532 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
533 self.operation_id.hash(hasher);
534 self.state.hash(hasher);
535 }
536}
537
538impl<S> Clone for OperationState<S>
539where
540 S: State,
541{
542 fn clone(&self) -> Self {
543 OperationState {
544 operation_id: self.operation_id,
545 state: self.state.clone(),
546 }
547 }
548}