1use std::marker::PhantomData;
2use std::sync::Arc;
34use fedimint_core::core::{ModuleInstanceId, OperationId};
5use fedimint_core::util::BoxStream;
6use fedimint_core::util::broadcaststream::BroadcastStream;
7use fedimint_logging::LOG_CLIENT;
8use futures::StreamExt as _;
9use tracing::{debug, error, trace};
1011use super::{DynState, State};
12use crate::module::FinalClientIface;
13use crate::sm::executor::{ActiveStateKey, InactiveStateKey};
14use crate::sm::{ActiveStateMeta, InactiveStateMeta};
1516/// State transition notifier for a specific module instance that can only
17/// subscribe to transitions belonging to that module
18#[derive(Debug, Clone)]
19pub struct ModuleNotifier<S> {
20 broadcast: tokio::sync::broadcast::Sender<DynState>,
21 module_instance: ModuleInstanceId,
22 client: FinalClientIface,
23/// `S` limits the type of state that can be subscribed to the one
24 /// associated with the module instance
25_pd: PhantomData<S>,
26}
2728impl<S> ModuleNotifier<S>
29where
30S: State,
31{
32pub fn new(
33 broadcast: tokio::sync::broadcast::Sender<DynState>,
34 module_instance: ModuleInstanceId,
35 client: FinalClientIface,
36 ) -> Self {
37Self {
38 broadcast,
39 module_instance,
40 client,
41 _pd: PhantomData,
42 }
43 }
4445// TODO: remove duplicates and order old transitions
46/// Subscribe to state transitions belonging to an operation and module
47 /// (module context contained in struct).
48 ///
49 /// The returned stream will contain all past state transitions that
50 /// happened before the subscription and are read from the database, after
51 /// these the stream will contain all future state transitions. The states
52 /// loaded from the database are not returned in a specific order. There may
53 /// also be duplications.
54pub async fn subscribe(&self, operation_id: OperationId) -> BoxStream<'static, S> {
55let to_typed_state = |state: DynState| {
56 state
57 .as_any()
58 .downcast_ref::<S>()
59 .expect("Tried to subscribe to wrong state type")
60 .clone()
61 };
6263// It's important to start the subscription first and then query the database to
64 // not lose any transitions in the meantime.
65let new_transitions = self.subscribe_all_operations();
6667let client_strong = self.client.get();
68let db_states = {
69let mut dbtx = client_strong.db().begin_transaction_nc().await;
70let active_states = client_strong
71 .read_operation_active_states(operation_id, self.module_instance, &mut dbtx)
72 .await
73.map(|(key, val): (ActiveStateKey, ActiveStateMeta)| {
74 (to_typed_state(key.state), val.created_at)
75 })
76 .collect::<Vec<(S, _)>>()
77 .await;
7879let inactive_states = self
80.client
81 .get()
82 .read_operation_inactive_states(operation_id, self.module_instance, &mut dbtx)
83 .await
84.map(|(key, val): (InactiveStateKey, InactiveStateMeta)| {
85 (to_typed_state(key.state), val.created_at)
86 })
87 .collect::<Vec<(S, _)>>()
88 .await;
8990// FIXME: don't rely on SystemTime for ordering and introduce a state transition
91 // index instead (dpc was right again xD)
92let mut all_states_timed = active_states
93 .into_iter()
94 .chain(inactive_states)
95 .collect::<Vec<(S, _)>>();
96 all_states_timed.sort_by(|(_, t1), (_, t2)| t1.cmp(t2));
97debug!(
98 operation_id = %operation_id.fmt_short(),
99 num = all_states_timed.len(),
100"Returning state transitions from DB for notifier subscription",
101 );
102 all_states_timed
103 .into_iter()
104 .map(|(s, _)| s)
105 .collect::<Vec<S>>()
106 };
107108let new_transitions = new_transitions.filter_map({
109let db_states: Arc<_> = Arc::new(db_states.clone());
110111move |state: S| {
112let db_states = db_states.clone();
113async move {
114if state.operation_id() == operation_id {
115trace!(operation_id = %operation_id.fmt_short(), ?state, "Received state transition notification");
116// Deduplicate events that might have both come from the DB and streamed,
117 // due to subscribing to notifier before querying the DB.
118 //
119 // Note: linear search should be good enough in practice for many reasons.
120 // Eg. states tend to have all the states in the DB, or all streamed "live",
121 // so the overlap here should be minimal.
122 // And we'll rewrite the whole thing anyway and use only db as a reference.
123if db_states.iter().any(|db_s| db_s == &state) {
124debug!(operation_id = %operation_id.fmt_short(), ?state, "Ignoring duplicated event");
125return None;
126 }
127Some(state)
128 } else {
129None
130}
131 }
132 }
133 });
134 Box::pin(futures::stream::iter(db_states).chain(new_transitions))
135 }
136137/// Subscribe to all state transitions belonging to the module instance.
138pub fn subscribe_all_operations(&self) -> BoxStream<'static, S> {
139let module_instance_id = self.module_instance;
140 Box::pin(
141 BroadcastStream::new(self.broadcast.subscribe())
142 .take_while(|res| {
143let cont = if let Err(err) = res {
144error!(target: LOG_CLIENT, ?err, "ModuleNotifier stream stopped on error");
145false
146} else {
147true
148};
149 std::future::ready(cont)
150 })
151 .filter_map(move |res| async move {
152let s = res.expect("We filtered out errors above");
153if s.module_instance_id() == module_instance_id {
154Some(
155 s.as_any()
156 .downcast_ref::<S>()
157 .expect("Tried to subscribe to wrong state type")
158 .clone(),
159 )
160 } else {
161None
162}
163 }),
164 )
165 }
166}