fedimint_core/util/mod.rs

1pub mod backoff_util;
2/// Copied from `tokio_stream` 0.1.12 to use our optional Send bounds
3pub mod broadcaststream;
4mod error;
5pub mod update_merge;
6
7use std::convert::Infallible;
8use std::fmt::{Debug, Display, Formatter};
9use std::future::Future;
10use std::hash::Hash;
11use std::io::Write;
12use std::path::Path;
13use std::pin::Pin;
14use std::str::FromStr;
15use std::{fs, io};
16
17use anyhow::format_err;
18pub use error::*;
19use fedimint_logging::LOG_CORE;
20use futures::StreamExt;
21use serde::{Deserialize, Serialize};
22use tokio::io::AsyncWriteExt;
23use tracing::{Instrument, Span, debug, warn};
24use url::{Host, ParseError, Url};
25
26use crate::task::MaybeSend;
27use crate::{apply, async_trait_maybe_send, maybe_add_send, runtime};
28
29/// Future that is `Send` unless targeting WASM
30pub type BoxFuture<'a, T> = Pin<Box<maybe_add_send!(dyn Future<Output = T> + 'a)>>;
31
32/// Stream that is `Send` unless targeting WASM
33pub type BoxStream<'a, T> = Pin<Box<maybe_add_send!(dyn futures::Stream<Item = T> + 'a)>>;
34
35#[apply(async_trait_maybe_send!)]
36pub trait NextOrPending {
37    type Output;
38
39    async fn next_or_pending(&mut self) -> Self::Output;
40
41    async fn ok(&mut self) -> anyhow::Result<Self::Output>;
42}
43
44#[apply(async_trait_maybe_send!)]
45impl<S> NextOrPending for S
46where
47    S: futures::Stream + Unpin + MaybeSend,
48    S::Item: MaybeSend,
49{
50    type Output = S::Item;
51
52    /// Waits for the next item in a stream. If the stream is closed while
53    /// waiting, returns an error.  Useful when expecting a stream to progress.
54    async fn ok(&mut self) -> anyhow::Result<Self::Output> {
55        self.next()
56            .await
57            .map_or_else(|| Err(format_err!("Stream was unexpectedly closed")), Ok)
58    }
59
60    /// Waits for the next item in a stream. If the stream is closed while
61    /// waiting the future will be pending forever. This is useful in cases
62    /// where the future will be cancelled by shutdown logic anyway and handling
63    /// each place where a stream may terminate would be too much trouble.
64    async fn next_or_pending(&mut self) -> Self::Output {
65        if let Some(item) = self.next().await {
66            item
67        } else {
68            debug!(target: LOG_CORE, "Stream ended in next_or_pending, pending forever to avoid throwing an error on shutdown");
69            std::future::pending().await
70        }
71    }
72}
73
74// TODO: make fully RFC1738 conformant
75/// Wrapper for `Url` that only prints the scheme, domain, port and path portion
76/// of a `Url` in its `Display` implementation.
77///
78/// This is useful to hide private
79/// information like user names and passwords in logs or UIs.
80///
81/// The output is not fully RFC1738 conformant but good enough for our current
82/// purposes.
83#[derive(Hash, Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
84// nosemgrep: ban-raw-url
85pub struct SafeUrl(Url);
86
87impl SafeUrl {
88    pub fn parse(url_str: &str) -> Result<Self, ParseError> {
89        Url::parse(url_str).map(SafeUrl)
90    }
91
92    /// Warning: This removes the safety.
93    // nosemgrep: ban-raw-url
94    pub fn to_unsafe(self) -> Url {
95        self.0
96    }
97
98    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
99    pub fn set_username(&mut self, username: &str) -> Result<(), ()> {
100        self.0.set_username(username)
101    }
102
103    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
104    pub fn set_password(&mut self, password: Option<&str>) -> Result<(), ()> {
105        self.0.set_password(password)
106    }
107
108    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
109    pub fn without_auth(&self) -> Result<Self, ()> {
110        let mut url = self.clone();
111
112        url.set_username("").and_then(|()| url.set_password(None))?;
113
114        Ok(url)
115    }
116
117    pub fn host(&self) -> Option<Host<&str>> {
118        self.0.host()
119    }
120    pub fn host_str(&self) -> Option<&str> {
121        self.0.host_str()
122    }
123    pub fn scheme(&self) -> &str {
124        self.0.scheme()
125    }
126    pub fn port(&self) -> Option<u16> {
127        self.0.port()
128    }
129    pub fn port_or_known_default(&self) -> Option<u16> {
130        self.0.port_or_known_default()
131    }
132    pub fn path(&self) -> &str {
133        self.0.path()
134    }
135    /// Warning: This will expose username & password if present.
136    pub fn as_str(&self) -> &str {
137        self.0.as_str()
138    }
139    pub fn username(&self) -> &str {
140        self.0.username()
141    }
142    pub fn password(&self) -> Option<&str> {
143        self.0.password()
144    }
145    pub fn join(&self, input: &str) -> Result<Self, ParseError> {
146        self.0.join(input).map(SafeUrl)
147    }
148
149    // It can be removed to use `is_onion_address()` implementation,
150    // once https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2214 lands.
151    #[allow(clippy::case_sensitive_file_extension_comparisons)]
152    pub fn is_onion_address(&self) -> bool {
153        let host = self.host_str().unwrap_or_default();
154
155        host.ends_with(".onion")
156    }
157
158    pub fn fragment(&self) -> Option<&str> {
159        self.0.fragment()
160    }
161
162    pub fn set_fragment(&mut self, arg: Option<&str>) {
163        self.0.set_fragment(arg);
164    }
165}
166
167impl Display for SafeUrl {
168    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
169        write!(f, "{}://", self.0.scheme())?;
170
171        if !self.0.username().is_empty() {
172            write!(f, "REDACTEDUSER")?;
173
174            if self.0.password().is_some() {
175                write!(f, ":REDACTEDPASS")?;
176            }
177
178            write!(f, "@")?;
179        }
180
181        if let Some(host) = self.0.host_str() {
182            write!(f, "{host}")?;
183        }
184
185        if let Some(port) = self.0.port() {
186            write!(f, ":{port}")?;
187        }
188
189        write!(f, "{}", self.0.path())?;
190
191        Ok(())
192    }
193}
194
195impl Debug for SafeUrl {
196    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
197        write!(f, "SafeUrl(")?;
198        Display::fmt(self, f)?;
199        write!(f, ")")?;
200        Ok(())
201    }
202}
203
204impl From<Url> for SafeUrl {
205    fn from(u: Url) -> Self {
206        Self(u)
207    }
208}
209
210impl FromStr for SafeUrl {
211    type Err = ParseError;
212
213    #[inline]
214    fn from_str(input: &str) -> Result<Self, ParseError> {
215        Self::parse(input)
216    }
217}
218
/// Creates the file at `path` and writes `contents` to it, like
/// [`std::fs::write`], except that it fails (with
/// [`io::ErrorKind::AlreadyExists`]) if the file already exists.
#[cfg(not(target_family = "wasm"))]
pub fn write_new<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut file = fs::OpenOptions::new()
        .create_new(true)
        .write(true)
        .open(path)?;
    file.write_all(contents.as_ref())
}
229
/// Writes `contents` to the file at `path`, creating the file if needed and
/// replacing any previous contents.
#[cfg(not(target_family = "wasm"))]
pub fn write_overwrite<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    // `fs::write` opens with create + truncate and then calls `write_all`.
    fs::write(path, contents)
}
239
240#[cfg(not(target_family = "wasm"))]
241pub async fn write_overwrite_async<P: AsRef<Path>, C: AsRef<[u8]>>(
242    path: P,
243    contents: C,
244) -> io::Result<()> {
245    tokio::fs::OpenOptions::new()
246        .write(true)
247        .create(true)
248        .truncate(true)
249        .open(path)
250        .await?
251        .write_all(contents.as_ref())
252        .await
253}
254
255#[cfg(not(target_family = "wasm"))]
256pub async fn write_new_async<P: AsRef<Path>, C: AsRef<[u8]>>(
257    path: P,
258    contents: C,
259) -> io::Result<()> {
260    tokio::fs::OpenOptions::new()
261        .write(true)
262        .create_new(true)
263        .open(path)
264        .await?
265        .write_all(contents.as_ref())
266        .await
267}
268
269#[derive(Debug, Clone)]
270pub struct Spanned<T> {
271    value: T,
272    span: Span,
273}
274
275impl<T> Spanned<T> {
276    pub async fn new<F: Future<Output = T>>(span: Span, make: F) -> Self {
277        Self::try_new::<Infallible, _>(span, async { Ok(make.await) })
278            .await
279            .unwrap()
280    }
281
282    pub async fn try_new<E, F: Future<Output = Result<T, E>>>(
283        span: Span,
284        make: F,
285    ) -> Result<Self, E> {
286        let span2 = span.clone();
287        async {
288            Ok(Self {
289                value: make.await?,
290                span: span2,
291            })
292        }
293        .instrument(span)
294        .await
295    }
296
297    pub fn borrow(&self) -> Spanned<&T> {
298        Spanned {
299            value: &self.value,
300            span: self.span.clone(),
301        }
302    }
303
304    pub fn map<U>(self, map: impl Fn(T) -> U) -> Spanned<U> {
305        Spanned {
306            value: map(self.value),
307            span: self.span,
308        }
309    }
310
311    pub fn borrow_mut(&mut self) -> Spanned<&mut T> {
312        Spanned {
313            value: &mut self.value,
314            span: self.span.clone(),
315        }
316    }
317
318    pub fn with_sync<O, F: FnOnce(T) -> O>(self, f: F) -> O {
319        let _g = self.span.enter();
320        f(self.value)
321    }
322
323    pub async fn with<Fut: Future, F: FnOnce(T) -> Fut>(self, f: F) -> Fut::Output {
324        async { f(self.value).await }.instrument(self.span).await
325    }
326
327    pub fn span(&self) -> Span {
328        self.span.clone()
329    }
330
331    pub fn value(&self) -> &T {
332        &self.value
333    }
334
335    pub fn value_mut(&mut self) -> &mut T {
336        &mut self.value
337    }
338
339    pub fn into_value(self) -> T {
340        self.value
341    }
342}
343
/// For CLIs: if the first command-line argument is exactly `version-hash`,
/// prints `version_hash` to stdout and exits the process with status 0.
/// Otherwise returns without doing anything.
pub fn handle_version_hash_command(version_hash: &str) {
    let first_arg = std::env::args().nth(1);
    if first_arg.as_deref() != Some("version-hash") {
        return;
    }
    println!("{version_hash}");
    std::process::exit(0);
}
355
356/// Run the supplied closure `op_fn` until it succeeds. Frequency and number of
357/// retries is determined by the specified strategy.
358///
359/// ```
360/// use std::time::Duration;
361///
362/// use fedimint_core::util::{backoff_util, retry};
363/// # tokio_test::block_on(async {
364/// retry(
365///     "Gateway balance after swap".to_string(),
366///     backoff_util::background_backoff(),
367///     || async {
368///         // Fallible network calls …
369///         Ok(())
370///     },
371/// )
372/// .await
373/// .expect("never fails");
374/// # });
375/// ```
376///
377/// # Returns
378///
379/// - If the closure runs successfully, the result is immediately returned
380/// - If the closure did not run successfully for `max_attempts` times, the
381///   error of the closure is returned
382pub async fn retry<F, Fut, T>(
383    op_name: impl Into<String>,
384    strategy: impl backoff_util::Backoff,
385    op_fn: F,
386) -> Result<T, anyhow::Error>
387where
388    F: Fn() -> Fut,
389    Fut: Future<Output = Result<T, anyhow::Error>>,
390{
391    let mut strategy = strategy;
392    let op_name = op_name.into();
393    let mut attempts: u64 = 0;
394    loop {
395        attempts += 1;
396        match op_fn().await {
397            Ok(result) => return Ok(result),
398            Err(err) => {
399                if let Some(interval) = strategy.next() {
400                    // run closure op_fn again
401                    debug!(
402                        target: LOG_CORE,
403                        err = %err.fmt_compact_anyhow(),
404                        %attempts,
405                        interval = interval.as_secs(),
406                        "{} failed, retrying",
407                        op_name,
408                    );
409                    runtime::sleep(interval).await;
410                } else {
411                    warn!(
412                        target: LOG_CORE,
413                        err = %err.fmt_compact_anyhow(),
414                        %attempts,
415                        "{} failed",
416                        op_name,
417                    );
418                    return Err(err);
419                }
420            }
421        }
422    }
423}
424
/// Computes the median of a slice of `u64`s that is already sorted.
///
/// Returns `None` for an empty slice. For an even number of elements the
/// result is the midpoint of the two middle values, rounded down.
pub fn get_median(vals: &[u64]) -> Option<u64> {
    let mid = vals.len() / 2;
    // `get` returns `None` exactly when the slice is empty.
    let upper = *vals.get(mid)?;
    if vals.len() % 2 == 1 {
        Some(upper)
    } else {
        Some(u64::midpoint(vals[mid - 1], upper))
    }
}
439
/// Computes the arithmetic mean of the given `u64` slice, rounded down.
///
/// Returns `None` for an empty slice. The sum is accumulated in a `u128`, so
/// many large values cannot overflow it.
pub fn get_average(vals: &[u64]) -> Option<u64> {
    if vals.is_empty() {
        return None;
    }

    // At most `len * u64::MAX`, which always fits in a `u128`.
    let sum: u128 = vals.iter().copied().map(u128::from).sum();
    // `usize` -> `u128` is lossless on every supported platform.
    let count = vals.len() as u128;
    // The mean of `u64` values never exceeds `u64::MAX`.
    Some(u64::try_from(sum / count).expect("average of u64 values fits in u64"))
}
449
450#[cfg(test)]
451mod tests;