use std::any::Any;
use std::fmt::Debug;
use std::future::Future;
use std::hash;
use std::io::{Error, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, ModuleKind, OperationId};
use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
use fedimint_core::module::registry::ModuleDecoderRegistry;
use fedimint_core::task::{MaybeSend, MaybeSync};
use fedimint_core::util::BoxFuture;
use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
use crate::sm::ClientSMDatabaseTransaction;
use crate::DynGlobalClientContext;
/// A state of a module's client state machine.
///
/// A state describes which transitions can move it to a next state. A state
/// whose `transitions` list is empty is terminal (this is what
/// `DynState::is_terminal` checks).
pub trait State:
    Debug
    + Clone
    + Eq
    + PartialEq
    + std::hash::Hash
    + Encodable
    + Decodable
    + MaybeSend
    + MaybeSync
    + 'static
{
    /// Module-specific context handed to [`State::transitions`].
    type ModuleContext: Context;

    /// Returns the transitions that are currently possible from this state.
    ///
    /// An empty vector marks the state as terminal.
    fn transitions(
        &self,
        context: &Self::ModuleContext,
        global_context: &DynGlobalClientContext,
    ) -> Vec<StateTransition<Self>>;

    /// Returns the id of the operation this state belongs to.
    fn operation_id(&self) -> OperationId;
}
/// Object-safe, type-erased counterpart of [`State`].
///
/// It is implemented for every `T: State` by a blanket impl, and `DynState`
/// also implements it by forwarding to its boxed inner state. This lets states
/// be handled behind a `DynState` without knowing their concrete type.
pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
    /// Returns `self` as `Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
    /// Type-erased [`State::transitions`].
    ///
    /// In the blanket impl, `context` must downcast to the concrete state's
    /// `ModuleContext`; otherwise it panics with "Wrong module".
    fn transitions(
        &self,
        context: &DynContext,
        global_context: &DynGlobalClientContext,
    ) -> Vec<StateTransition<DynState>>;
    /// Returns the id of the operation this state belongs to.
    fn operation_id(&self) -> OperationId;
    /// Clones the state into a new `DynState` tagged with `module_instance_id`.
    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;
    /// Compares the inner states and ignores the module instance id.
    ///
    /// The blanket impl panics if `other` does not wrap the same concrete type.
    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;
    /// Feeds the inner state, without the module instance id, into `hasher`.
    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);
}
/// Object-safe counterpart of [`Context`], implemented for every `T: Context`.
pub trait IContext: Debug {
    /// Returns `self` as `Any` so it can be downcast to the concrete context.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
    /// Returns the module kind declared by the concrete context's `Context::KIND`.
    fn module_kind(&self) -> Option<ModuleKind>;
}
// Defines `DynContext`, a `Clone`-able newtype around a reference-counted,
// type-erased module context (`Arc<IContext>`). The set of generated helper
// impls comes from the macro definition, which is not visible here.
module_plugin_dyn_newtype_define! {
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
/// Module-specific context made available to a module's state transitions.
pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {
    /// Kind of the module this context belongs to, reported through
    /// `IContext::module_kind`.
    /// NOTE(review): presumably `None` for contexts that are not tied to a
    /// single module kind; confirm against implementors.
    const KIND: Option<ModuleKind>;
}
/// Every concrete [`Context`] is usable as a type-erased [`IContext`].
impl<T> IContext for T
where
    T: Context + 'static + MaybeSend + MaybeSync,
{
    /// Upcasts the concrete context to `Any` for later downcasting.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
        self
    }

    /// Reports the kind declared by the concrete context type.
    fn module_kind(&self) -> Option<ModuleKind> {
        <T as Context>::KIND
    }
}
/// Boxed future that a transition waits on.
///
/// Its output is the trigger's result serialized to JSON
/// (`StateTransition::new` does the serialization), which erases the trigger's
/// concrete output type.
type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;
/// Shared function that computes the next state of type `S`.
///
/// It takes the state-machine database transaction, the trigger's JSON output,
/// and the current state. It returns a boxed future that yields the next state
/// and may borrow the transaction for `'a`.
pub(super) type StateTransitionFunction<S> = Arc<
    maybe_add_send_sync!(
        dyn for<'a> Fn(
            &'a mut ClientSMDatabaseTransaction<'_, '_>,
            serde_json::Value,
            S,
        ) -> BoxFuture<'a, S>
    ),
>;
/// A possible transition out of a state of type `S`.
///
/// The `transition` function is given the trigger's (JSON) output together with
/// the current state, and produces the next state.
pub struct StateTransition<S> {
    /// Future whose output is fed to `transition`.
    pub trigger: TriggerFuture,
    /// Function mapping the trigger output and the current state to the next state.
    pub transition: StateTransitionFunction<S>,
}
impl<S> StateTransition<S> {
    /// Creates a transition from a typed trigger future and a typed transition function.
    ///
    /// The trigger's output `V` is serialized to a [`serde_json::Value`], so that
    /// every trigger shares the type-erased [`TriggerFuture`] type. The returned
    /// transition function deserializes that value back into `V` before it
    /// invokes `transition`.
    ///
    /// # Panics
    ///
    /// The trigger future panics if `V` cannot be serialized to JSON. The
    /// transition function panics if the JSON value cannot be deserialized back
    /// into `V`.
    pub fn new<V, Trigger, TransitionFn>(
        trigger: Trigger,
        transition: TransitionFn,
    ) -> StateTransition<S>
    where
        S: MaybeSend + MaybeSync + Clone + 'static,
        V: serde::Serialize + serde::de::DeserializeOwned + Send,
        Trigger: Future<Output = V> + MaybeSend + 'static,
        TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
            + MaybeSend
            + MaybeSync
            + Clone
            + 'static,
    {
        StateTransition {
            trigger: Box::pin(async move {
                let val = trigger.await;
                serde_json::to_value(val).expect("Value could not be serialized")
            }),
            transition: Arc::new(move |dbtx, val, state| {
                // The boxed future must own its own copy of the function. The
                // future is tied to `dbtx`'s lifetime, while `transition` is only
                // borrowed from this shared closure.
                let transition = transition.clone();
                Box::pin(async move {
                    let typed_val: V = serde_json::from_value(val)
                        .expect("Deserialize trigger return value failed");
                    // `state` is already owned by this future, so move it instead
                    // of deep-copying it on every transition.
                    transition(dbtx, typed_val, state).await
                })
            }),
        }
    }
}
impl<T> IState for T
where
T: State,
{
fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
self
}
fn transitions(
&self,
context: &DynContext,
global_context: &DynGlobalClientContext,
) -> Vec<StateTransition<DynState>> {
<T as State>::transitions(
self,
context.as_any().downcast_ref().expect("Wrong module"),
global_context,
)
.into_iter()
.map(|st| StateTransition {
trigger: st.trigger,
transition: Arc::new(
move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
let transition = st.transition.clone();
Box::pin(async move {
let new_state = transition(
dbtx,
val,
state
.as_any()
.downcast_ref::<T>()
.expect("Wrong module")
.clone(),
)
.await;
DynState::from_typed(state.module_instance_id(), new_state)
})
},
),
})
.collect()
}
fn operation_id(&self) -> OperationId {
<T as State>::operation_id(self)
}
fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
}
fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
let other: &T = other
.as_any()
.downcast_ref()
.expect("Type is ensured in previous step");
self == other
}
fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
self.hash(&mut hasher);
}
}
/// Type-erased state of some module, tagged with the module instance it belongs to.
///
/// Field `0` is the boxed inner state. Field `1` is its module instance id,
/// which is included in equality, hashing and encoding.
pub struct DynState(
    Box<maybe_add_send_sync!(dyn IState + 'static)>,
    ModuleInstanceId,
);
impl IState for DynState {
fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
(**self).as_any()
}
fn transitions(
&self,
context: &DynContext,
global_context: &DynGlobalClientContext,
) -> Vec<StateTransition<DynState>> {
(**self).transitions(context, global_context)
}
fn operation_id(&self) -> OperationId {
(**self).operation_id()
}
fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
(**self).clone(module_instance_id)
}
fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
(**self).erased_eq_no_instance_id(other)
}
fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher) {
(**self).erased_hash_no_instance_id(hasher);
}
}
/// A `DynState` is already type-erased; converting it only checks the instance id.
impl IntoDynInstance for DynState {
    type DynType = DynState;

    /// Returns `self` unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `instance_id` differs from the id the state already carries.
    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
        assert_eq!(instance_id, self.module_instance_id());
        self
    }
}
/// Gives shared access to the boxed inner state.
impl std::ops::Deref for DynState {
    type Target = maybe_add_send_sync!(dyn IState + 'static);

    fn deref(&self) -> &<Self as std::ops::Deref>::Target {
        self.0.as_ref()
    }
}
impl hash::Hash for DynState {
fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
self.1.hash(hasher);
self.0.erased_hash_no_instance_id(hasher);
}
}
impl DynState {
    /// Returns the id of the module instance this state belongs to.
    pub fn module_instance_id(&self) -> ModuleInstanceId {
        let Self(_, module_instance_id) = self;
        *module_instance_id
    }

    /// Boxes a concrete state and tags it with `module_instance_id`.
    pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
    where
        I: IState + 'static,
    {
        Self::from_parts(module_instance_id, Box::new(typed))
    }

    /// Builds a `DynState` from an already boxed state and its module instance id.
    pub fn from_parts(
        module_instance_id: ::fedimint_core::core::ModuleInstanceId,
        dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
    ) -> Self {
        Self(dynbox, module_instance_id)
    }
}
impl std::fmt::Debug for DynState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.0, f)
}
}
/// Gives mutable access to the boxed inner state.
impl std::ops::DerefMut for DynState {
    fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
        self.0.as_mut()
    }
}
impl Clone for DynState {
fn clone(&self) -> Self {
self.0.clone(self.1)
}
}
/// Two `DynState`s are equal when they belong to the same module instance and
/// their inner states are equal.
///
/// The inner comparison only runs after the instance ids match.
impl PartialEq for DynState {
    fn eq(&self, other: &Self) -> bool {
        self.module_instance_id() == other.module_instance_id()
            && self.erased_eq_no_instance_id(other)
    }
}
// Marker impl: equality is the `PartialEq` impl above (instance id, then the
// inner state's `Eq`-bounded comparison).
impl Eq for DynState {}
impl Encodable for DynState {
    /// Encodes the module instance id followed by the inner state.
    ///
    /// Returns the total number of bytes written, counting both parts. The
    /// previous implementation dropped the length of the instance id.
    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
        let mut len = self.1.consensus_encode(writer)?;
        len += self.0.consensus_encode_dyn(writer)?;
        Ok(len)
    }
}
impl Decodable for DynState {
fn consensus_decode<R: std::io::Read>(
reader: &mut R,
decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
) -> Result<Self, fedimint_core::encoding::DecodeError> {
let module_id = fedimint_core::core::ModuleInstanceId::consensus_decode(reader, decoders)?;
decoders
.get_expect(module_id)
.decode_partial(reader, module_id, decoders)
}
}
impl DynState {
pub fn is_terminal(
&self,
context: &DynContext,
global_context: &DynGlobalClientContext,
) -> bool {
self.transitions(context, global_context).is_empty()
}
}
/// Wraps a module state `S` together with the operation it belongs to.
///
/// Its `State::operation_id` returns the stored `operation_id` field rather than
/// asking `state`.
#[derive(Debug)]
pub struct OperationState<S> {
    /// Operation this state is associated with.
    pub operation_id: OperationId,
    /// The wrapped module state.
    pub state: S,
}
impl<S> State for OperationState<S>
where
S: State,
{
type ModuleContext = S::ModuleContext;
fn transitions(
&self,
context: &Self::ModuleContext,
global_context: &DynGlobalClientContext,
) -> Vec<StateTransition<Self>> {
let transitions: Vec<StateTransition<OperationState<S>>> = self
.state
.transitions(context, global_context)
.into_iter()
.map(
|StateTransition {
trigger,
transition,
}| {
let op_transition: StateTransitionFunction<Self> =
Arc::new(move |dbtx, value, op_state| {
let transition = transition.clone();
Box::pin(async move {
let state = transition(dbtx, value, op_state.state).await;
OperationState {
operation_id: op_state.operation_id,
state,
}
})
});
StateTransition {
trigger,
transition: op_transition,
}
},
)
.collect();
transitions
}
fn operation_id(&self) -> OperationId {
self.operation_id
}
}
/// Boxes the operation state into a `DynState` tagged with `instance_id`.
impl<S> IntoDynInstance for OperationState<S>
where
    S: State,
{
    type DynType = DynState;

    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
        DynState::from_parts(instance_id, Box::new(self))
    }
}
impl<S> Encodable for OperationState<S>
where
    S: State,
{
    /// Encodes `operation_id` then `state` and returns the total bytes written.
    fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
        let id_len = self.operation_id.consensus_encode(writer)?;
        let state_len = self.state.consensus_encode(writer)?;
        Ok(id_len + state_len)
    }
}
impl<S> Decodable for OperationState<S>
where
    S: State,
{
    /// Decodes `operation_id` then `state`, mirroring the `Encodable` impl.
    fn consensus_decode<R: Read>(
        read: &mut R,
        modules: &ModuleDecoderRegistry,
    ) -> Result<Self, DecodeError> {
        // Struct-literal fields are evaluated in the order written, which keeps
        // the wire order: operation id first, then the state.
        Ok(Self {
            operation_id: OperationId::consensus_decode(read, modules)?,
            state: S::consensus_decode(read, modules)?,
        })
    }
}
/// Field-wise equality: operation id first, then the wrapped state.
impl<S> PartialEq for OperationState<S>
where
    S: State,
{
    fn eq(&self, other: &Self) -> bool {
        self.operation_id == other.operation_id && self.state == other.state
    }
}
// Marker impl: equality is the field-wise comparison from the `PartialEq` impl,
// and the `S: State` bound guarantees `S: Eq`.
impl<S> Eq for OperationState<S> where S: State {}
impl<S> hash::Hash for OperationState<S>
where
S: hash::Hash,
{
fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
self.operation_id.hash(hasher);
self.state.hash(hasher);
}
}
/// Copies the operation id and clones the wrapped state.
impl<S> Clone for OperationState<S>
where
    S: State,
{
    fn clone(&self) -> Self {
        Self {
            operation_id: self.operation_id,
            state: Clone::clone(&self.state),
        }
    }
}