fedimint_core/util/
broadcaststream.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::{Stream, ready};
7use tokio::sync::broadcast::Receiver;
8use tokio::sync::broadcast::error::RecvError;
9
10use crate::task::MaybeSend;
11use crate::util::BoxFuture;
12
13/// A wrapper around [`tokio::sync::broadcast::Receiver`] that implements
14/// [`Stream`].
15///
16/// [`tokio::sync::broadcast::Receiver`]: struct@tokio::sync::broadcast::Receiver
17/// [`Stream`]: trait@futures::Stream
18pub struct BroadcastStream<T> {
19    inner: BoxFuture<'static, (Result<T, RecvError>, Receiver<T>)>,
20}
21
/// Error yielded by a [`BroadcastStream`] when a message could not be
/// delivered in order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BroadcastStreamRecvError {
    /// The receiver fell too far behind the sender. The next receive attempt
    /// returns the oldest message the channel still retains.
    ///
    /// Carries the number of messages that were skipped.
    Lagged(u64),
}

impl fmt::Display for BroadcastStreamRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructure is irrefutable.
        let Self::Lagged(skipped) = self;
        write!(f, "channel lagged by {skipped}")
    }
}

impl std::error::Error for BroadcastStreamRecvError {}
41
42async fn make_future<T: Clone>(mut rx: Receiver<T>) -> (Result<T, RecvError>, Receiver<T>) {
43    let result = rx.recv().await;
44    (result, rx)
45}
46
47impl<T: 'static + Clone + MaybeSend> BroadcastStream<T> {
48    /// Create a new `BroadcastStream`.
49    pub fn new(rx: Receiver<T>) -> Self {
50        Self {
51            inner: Box::pin(make_future(rx)),
52        }
53    }
54}
55
56impl<T: 'static + Clone + MaybeSend> Stream for BroadcastStream<T> {
57    type Item = Result<T, BroadcastStreamRecvError>;
58    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59        let (result, rx) = ready!(Pin::new(&mut self.inner).poll(cx));
60        self.inner = Box::pin(make_future(rx));
61        match result {
62            Ok(item) => Poll::Ready(Some(Ok(item))),
63            Err(RecvError::Closed) => Poll::Ready(None),
64            Err(RecvError::Lagged(n)) => {
65                Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n))))
66            }
67        }
68    }
69}
70
71impl<T> fmt::Debug for BroadcastStream<T> {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        f.debug_struct("BroadcastStream").finish()
74    }
75}
76
77impl<T: 'static + Clone + MaybeSend> From<Receiver<T>> for BroadcastStream<T> {
78    fn from(recv: Receiver<T>) -> Self {
79        Self::new(recv)
80    }
81}