use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
use async_trait::async_trait;
use fedimint_core::net::peers::{IMuxPeerConnections, PeerConnections};
use fedimint_core::runtime::spawn;
use fedimint_core::task::{Cancellable, Cancelled};
use fedimint_core::PeerId;
use fedimint_logging::LOG_NET_PEER;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tracing::{debug, warn};
/// Owned module identifier.
pub type ModuleId = String;

/// Borrowed counterpart of [`ModuleId`].
pub type ModuleIdRef<'a> = &'a str;

/// Per-peer limit on buffered out-of-order messages (messages that arrived before a matching
/// `receive` call). Past this limit, new messages from that peer are dropped with a warning.
pub const MAX_PEER_OUT_OF_ORDER_MESSAGES: u64 = 10_000;
/// Envelope sent over the underlying peer connections: `msg` tagged with the multiplexing
/// `key` that routes it to the matching `receive(key)` call on the receiving side.
///
/// NOTE(review): serde's derive serializes fields in declaration order, so reordering the
/// fields may change the wire encoding for non-self-describing formats. Check before editing.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModuleMultiplexed<MuxKey, Msg> {
    /// Routing key identifying which consumer the message is addressed to.
    pub key: MuxKey,
    /// The payload itself.
    pub msg: Msg,
}
/// Buffering state owned by the multiplexer's background task.
///
/// Messages that arrive before anyone is waiting for their key are queued in `msgs`.
/// `receive` calls that arrive before a matching message are queued in `callbacks`.
struct ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    /// Received-but-undelivered messages per key, in arrival order.
    msgs: HashMap<MuxKey, VecDeque<(PeerId, Msg)>>,
    /// Pending `receive` calls per key, served in registration order.
    callbacks: HashMap<MuxKey, VecDeque<oneshot::Sender<(PeerId, Msg)>>>,
    /// How many messages from each peer are currently buffered in `msgs`; checked against
    /// `MAX_PEER_OUT_OF_ORDER_MESSAGES` before buffering another one.
    peer_counts: HashMap<PeerId, u64>,
}
// Written by hand rather than derived: `#[derive(Default)]` would demand `MuxKey: Default`
// and `Msg: Default`, although an empty buffer needs neither.
impl<MuxKey, Msg> Default for ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    /// Returns a state with no buffered messages, no pending receivers and no peer counts.
    fn default() -> Self {
        let msgs = HashMap::default();
        let callbacks = HashMap::default();
        let peer_counts = HashMap::default();
        Self {
            msgs,
            callbacks,
            peer_counts,
        }
    }
}
/// Cloneable handle that lets many consumers share one set of [`PeerConnections`], each
/// sending and receiving under its own multiplexing key.
///
/// All handles talk to a single background task (spawned by
/// [`PeerConnectionMultiplexer::new`]) through the channels below.
pub struct PeerConnectionMultiplexer<MuxKey, Msg> {
    /// Outgoing messages: target peers, key and payload.
    send_requests_tx: Sender<(Vec<PeerId>, MuxKey, Msg)>,
    /// Registrations of pending `receive` calls.
    receive_callbacks_tx: Sender<Callback<MuxKey, Msg>>,
    /// Requests to ban a peer on the underlying connections.
    peer_bans_tx: Sender<PeerId>,
}

// Implemented by hand: `#[derive(Clone)]` would add `MuxKey: Clone` and `Msg: Clone` bounds,
// yet cloning only duplicates channel senders, which are `Clone` for any payload type.
impl<MuxKey, Msg> Clone for PeerConnectionMultiplexer<MuxKey, Msg> {
    fn clone(&self) -> Self {
        Self {
            send_requests_tx: self.send_requests_tx.clone(),
            receive_callbacks_tx: self.receive_callbacks_tx.clone(),
            peer_bans_tx: self.peer_bans_tx.clone(),
        }
    }
}

/// A pending `receive` call: the key it waits on and the channel the matching
/// `(sender peer, message)` pair is delivered through.
type Callback<MuxKey, Msg> = (MuxKey, oneshot::Sender<(PeerId, Msg)>);
impl<MuxKey, Msg> PeerConnectionMultiplexer<MuxKey, Msg>
where
Msg: Serialize + DeserializeOwned + Unpin + Send + Debug + 'static,
MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone + 'static,
{
pub fn new(connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>) -> Self {
let (send_requests_tx, send_requests_rx) = channel(1000);
let (receive_callbacks_tx, receive_callbacks_rx) = channel(1000);
let (peer_bans_tx, peer_bans_rx) = channel(1000);
spawn(
"peer connection multiplexer",
Self::run(
connections,
ModuleMultiplexerOutOfOrder::default(),
send_requests_rx,
receive_callbacks_rx,
peer_bans_rx,
),
);
Self {
send_requests_tx,
receive_callbacks_tx,
peer_bans_tx,
}
}
async fn run(
mut connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>,
mut out_of_order: ModuleMultiplexerOutOfOrder<MuxKey, Msg>,
mut send_requests_rx: Receiver<(Vec<PeerId>, MuxKey, Msg)>,
mut receive_callbacks_rx: Receiver<Callback<MuxKey, Msg>>,
mut peer_bans_rx: Receiver<PeerId>,
) -> Cancellable<()> {
loop {
let mut key_inserted: Option<MuxKey> = None;
tokio::select! {
send_request = send_requests_rx.recv() => {
let (peers, key, msg) = send_request.ok_or(Cancelled)?;
connections.send(&peers, ModuleMultiplexed { key, msg }).await?;
}
peer_ban = peer_bans_rx.recv() => {
let peer = peer_ban.ok_or(Cancelled)?;
connections.ban_peer(peer).await;
}
receive_callback = receive_callbacks_rx.recv() => {
let (key, callback) = receive_callback.ok_or(Cancelled)?;
out_of_order.callbacks.entry(key.clone()).or_default().push_back(callback);
key_inserted = Some(key);
}
receive = connections.receive() => {
let (peer, ModuleMultiplexed { key, msg }) = receive?;
let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
if *peer_pending > MAX_PEER_OUT_OF_ORDER_MESSAGES {
warn!(
target: LOG_NET_PEER,
"Peer {peer} has {peer_pending} pending messages. Dropping new message."
);
} else {
*peer_pending += 1;
out_of_order.msgs.entry(key.clone()).or_default().push_back((peer, msg));
key_inserted = Some(key);
}
}
}
if let Some(key) = key_inserted {
let callbacks = out_of_order.callbacks.entry(key.clone()).or_default();
let msgs = out_of_order.msgs.entry(key.clone()).or_default();
if !callbacks.is_empty() && !msgs.is_empty() {
let callback = callbacks.pop_front().expect("checked");
let (peer, msg) = msgs.pop_front().expect("checked");
let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
*peer_pending -= 1;
callback.send((peer, msg)).map_err(|_| Cancelled)?;
}
}
}
}
}
#[async_trait]
impl<MuxKey, Msg> IMuxPeerConnections<MuxKey, Msg> for PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone,
{
    /// Queues `msg` for delivery under `key` to every peer in `peers`.
    ///
    /// Fails with `Cancelled` if the background task is no longer running.
    async fn send(&self, peers: &[PeerId], key: MuxKey, msg: Msg) -> Cancellable<()> {
        debug!("Sending to {peers:?}/{key:?}, {msg:?}");
        let request = (peers.to_vec(), key, msg);
        match self.send_requests_tx.send(request).await {
            Ok(()) => Ok(()),
            Err(_) => Err(Cancelled),
        }
    }

    /// Waits for the next message addressed to `key` and returns it with its sender.
    ///
    /// Fails with `Cancelled` if the background task is gone, either before the request is
    /// registered or before a message is delivered.
    async fn receive(&self, key: MuxKey) -> Cancellable<(PeerId, Msg)> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let registration = (key, reply_tx);
        if self.receive_callbacks_tx.send(registration).await.is_err() {
            return Err(Cancelled);
        }
        reply_rx.await.map_err(|_| Cancelled)
    }

    /// Asks the background task to ban `peer`.
    async fn ban_peer(&self, peer: PeerId) {
        // Best effort: if the background task has stopped there is nothing left to ban through,
        // so a failed hand-off is deliberately ignored.
        let _ = self.peer_bans_tx.send(peer).await;
    }
}
#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use fedimint_core::net::peers::fake::make_fake_peer_connection;
    use fedimint_core::net::peers::IMuxPeerConnections;
    use fedimint_core::task::{self, TaskGroup};
    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    use crate::multiplexed::PeerConnectionMultiplexer;

    /// Stress test: many keys stream ordered messages concurrently over one pair of
    /// multiplexed connections, with random delays on both the sending and receiving side.
    /// Each receiver must see its key's messages in the order they were sent.
    #[test_log::test(tokio::test)]
    async fn test_multiplexer() {
        const NUM_MODULES: usize = 128;
        const NUM_MSGS_PER_MODULE: usize = 128;
        const NUM_REPEAT_TEST: usize = 10;

        for _round in 0..NUM_REPEAT_TEST {
            let group = TaskGroup::new();
            let handle = group.make_handle();

            let sender_id = PeerId::from(0);
            let receiver_id = PeerId::from(1);

            let (raw_sender, raw_receiver) =
                make_fake_peer_connection(sender_id, receiver_id, 1000, handle.clone());
            let sender_conn = PeerConnectionMultiplexer::new(raw_sender).into_dyn();
            let receiver_conn = PeerConnectionMultiplexer::new(raw_receiver).into_dyn();

            let mut module_keys: Vec<usize> = (0..NUM_MODULES).collect();

            // Start senders in a random order...
            module_keys.shuffle(&mut thread_rng());
            for &mux_key in &module_keys {
                let conn = sender_conn.clone();
                let shutdown = handle.clone();
                group.spawn(format!("sender-{mux_key}"), move |_| async move {
                    for payload in 0..NUM_MSGS_PER_MODULE {
                        if OsRng.gen::<bool>() {
                            task::sleep(Duration::from_millis(2)).await;
                        }
                        if shutdown.is_shutting_down() {
                            break;
                        }
                        conn.send(&[receiver_id], mux_key, payload).await.unwrap();
                    }
                });
            }

            // ...and receivers in a different random order.
            module_keys.shuffle(&mut thread_rng());
            for &mux_key in &module_keys {
                let conn = receiver_conn.clone();
                group.spawn(format!("receiver-{mux_key}"), move |_| async move {
                    for expected in 0..NUM_MSGS_PER_MODULE {
                        if OsRng.gen::<bool>() {
                            task::sleep(Duration::from_millis(1)).await;
                        }
                        let received = conn.receive(mux_key).await.unwrap();
                        assert_eq!(received, (sender_id, expected));
                    }
                });
            }

            group.join_all(None).await.expect("no failures");
        }
    }
}