fedimint_server/net/p2p_connection.rs

1use std::io::Cursor;
2
3use anyhow::Context;
4use async_trait::async_trait;
5use bytes::Bytes;
6use fedimint_core::encoding::{Decodable, Encodable};
7use futures::{SinkExt, StreamExt};
8use iroh::endpoint::Connection;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use tokio::net::TcpStream;
12use tokio_rustls::TlsStream;
13use tokio_util::codec::{Framed, LengthDelimitedCodec};
14
15pub type DynP2PConnection<M> = Box<dyn IP2PConnection<M>>;
16
17#[async_trait]
18pub trait IP2PConnection<M>: Send + 'static {
19    async fn send(&mut self, message: M) -> anyhow::Result<()>;
20
21    async fn receive(&mut self) -> anyhow::Result<M>;
22
23    fn into_dyn(self) -> DynP2PConnection<M>
24    where
25        Self: Sized,
26    {
27        Box::new(self)
28    }
29}
30
31#[async_trait]
32impl<M> IP2PConnection<M> for Framed<TlsStream<TcpStream>, LengthDelimitedCodec>
33where
34    M: Encodable + Decodable + Serialize + DeserializeOwned + Send + 'static,
35{
36    async fn send(&mut self, message: M) -> anyhow::Result<()> {
37        let mut bytes = Vec::new();
38
39        bincode::serialize_into(&mut bytes, &message)?;
40
41        SinkExt::send(self, Bytes::from_owner(bytes)).await?;
42
43        Ok(())
44    }
45
46    async fn receive(&mut self) -> anyhow::Result<M> {
47        Ok(bincode::deserialize_from(Cursor::new(
48            &self.next().await.context("Framed stream is closed")??,
49        ))?)
50    }
51}
52
53#[async_trait]
54impl<M> IP2PConnection<M> for Connection
55where
56    M: Serialize + DeserializeOwned + Send + 'static,
57{
58    async fn send(&mut self, message: M) -> anyhow::Result<()> {
59        let mut bytes = Vec::new();
60
61        bincode::serialize_into(&mut bytes, &message)?;
62
63        let mut sink = self.open_uni().await?;
64
65        sink.write_all(&bytes).await?;
66
67        sink.finish()?;
68
69        Ok(())
70    }
71
72    async fn receive(&mut self) -> anyhow::Result<M> {
73        Ok(bincode::deserialize_from(Cursor::new(
74            &self.accept_uni().await?.read_to_end(1_000_000_000).await?,
75        ))?)
76    }
77}