Skip to main content

fedimint_core/util/
mod.rs

1pub mod backoff_util;
2/// Copied from `tokio_stream` 0.1.12 to use our optional Send bounds
3pub mod broadcaststream;
4pub mod update_merge;
5
6use std::convert::Infallible;
7use std::fmt::{Debug, Display, Formatter};
8use std::future::Future;
9use std::hash::Hash;
10use std::io::Write;
11use std::path::Path;
12use std::pin::Pin;
13use std::str::FromStr;
14use std::sync::LazyLock;
15use std::{fs, io};
16
17use anyhow::format_err;
18use fedimint_logging::LOG_CORE;
19pub use fedimint_util_error::*;
20use futures::StreamExt;
21use serde::{Deserialize, Serialize};
22use tokio::io::AsyncWriteExt;
23use tracing::{Instrument, Span, debug, warn};
24use url::{Host, ParseError, Url};
25
26use crate::envs::{FM_DEBUG_SHOW_SECRETS_ENV, is_env_var_set};
27use crate::task::MaybeSend;
28use crate::{apply, async_trait_maybe_send, maybe_add_send, runtime};
29
/// Future that is `Send` unless targeting WASM
///
/// Pinned and boxed, so it can be stored or returned without naming the
/// concrete future type. Whether `Send` is required is decided by the
/// `maybe_add_send!` macro.
pub type BoxFuture<'a, T> = Pin<Box<maybe_add_send!(dyn Future<Output = T> + 'a)>>;

/// Stream that is `Send` unless targeting WASM
///
/// Pinned and boxed counterpart of [`BoxFuture`] for `futures::Stream`s;
/// `Send` is again controlled by `maybe_add_send!`.
pub type BoxStream<'a, T> = Pin<Box<maybe_add_send!(dyn futures::Stream<Item = T> + 'a)>>;
35
36#[apply(async_trait_maybe_send!)]
37pub trait NextOrPending {
38    type Output;
39
40    async fn next_or_pending(&mut self) -> Self::Output;
41
42    async fn ok(&mut self) -> anyhow::Result<Self::Output>;
43}
44
45#[apply(async_trait_maybe_send!)]
46impl<S> NextOrPending for S
47where
48    S: futures::Stream + Unpin + MaybeSend,
49    S::Item: MaybeSend,
50{
51    type Output = S::Item;
52
53    /// Waits for the next item in a stream. If the stream is closed while
54    /// waiting, returns an error.  Useful when expecting a stream to progress.
55    async fn ok(&mut self) -> anyhow::Result<Self::Output> {
56        self.next()
57            .await
58            .map_or_else(|| Err(format_err!("Stream was unexpectedly closed")), Ok)
59    }
60
61    /// Waits for the next item in a stream. If the stream is closed while
62    /// waiting the future will be pending forever. This is useful in cases
63    /// where the future will be cancelled by shutdown logic anyway and handling
64    /// each place where a stream may terminate would be too much trouble.
65    async fn next_or_pending(&mut self) -> Self::Output {
66        if let Some(item) = self.next().await {
67            item
68        } else {
69            debug!(target: LOG_CORE, "Stream ended in next_or_pending, pending forever to avoid throwing an error on shutdown");
70            std::future::pending().await
71        }
72    }
73}
74
// TODO: make fully RFC1738 conformant
/// Wrapper for `Url` whose `Display` (and `Debug`) output hides credentials.
///
/// `Display` prints the scheme, host, port and path. If the URL has a
/// username, the userinfo is replaced by the placeholders `REDACTEDUSER`
/// (and `:REDACTEDPASS` when a password is present), unless the environment
/// variable named by `FM_DEBUG_SHOW_SECRETS_ENV` is set. Query and fragment
/// are never printed.
///
/// This is useful to hide private
/// information like user names and passwords in logs or UIs.
///
/// The output is not fully RFC1738 conformant but good enough for our current
/// purposes.
///
/// NOTE(review): only formatting is redacted. The derived `Serialize`
/// presumably serializes the inner `Url` verbatim (credentials included), and
/// `as_str` / `to_unsafe` expose the full URL as well.
#[derive(Hash, Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
// nosemgrep: ban-raw-url
pub struct SafeUrl(Url);
87
88impl SafeUrl {
89    pub fn parse(url_str: &str) -> Result<Self, ParseError> {
90        Url::parse(url_str).map(SafeUrl)
91    }
92
93    /// Warning: This removes the safety.
94    // nosemgrep: ban-raw-url
95    pub fn to_unsafe(self) -> Url {
96        self.0
97    }
98
99    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
100    pub fn set_username(&mut self, username: &str) -> Result<(), ()> {
101        self.0.set_username(username)
102    }
103
104    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
105    pub fn set_password(&mut self, password: Option<&str>) -> Result<(), ()> {
106        self.0.set_password(password)
107    }
108
109    #[allow(clippy::result_unit_err)] // just copying `url`'s API here
110    pub fn without_auth(&self) -> Result<Self, ()> {
111        let mut url = self.clone();
112
113        url.set_username("").and_then(|()| url.set_password(None))?;
114
115        Ok(url)
116    }
117
118    pub fn host(&self) -> Option<Host<&str>> {
119        self.0.host()
120    }
121    pub fn host_str(&self) -> Option<&str> {
122        self.0.host_str()
123    }
124    pub fn scheme(&self) -> &str {
125        self.0.scheme()
126    }
127    pub fn port(&self) -> Option<u16> {
128        self.0.port()
129    }
130    pub fn port_or_known_default(&self) -> Option<u16> {
131        self.0.port_or_known_default()
132    }
133    pub fn path(&self) -> &str {
134        self.0.path()
135    }
136    /// Warning: This will expose username & password if present.
137    pub fn as_str(&self) -> &str {
138        self.0.as_str()
139    }
140    pub fn username(&self) -> &str {
141        self.0.username()
142    }
143    pub fn password(&self) -> Option<&str> {
144        self.0.password()
145    }
146    pub fn join(&self, input: &str) -> Result<Self, ParseError> {
147        self.0.join(input).map(SafeUrl)
148    }
149
150    /// Append a relative path, ensuring exactly one `/` between
151    /// the base and the path segment.
152    ///
153    /// Unlike `Url::join` (RFC 3986), this never drops path
154    /// segments from the base — it always appends.
155    pub fn join_path(&self, path: &str) -> Self {
156        let base = self.to_string();
157        let base = base.trim_end_matches('/');
158        let path = path.trim_start_matches('/');
159        Self::parse(&format!("{base}/{path}"))
160            .expect("appending a relative path to a valid URL should produce a valid URL")
161    }
162
163    // It can be removed to use `is_onion_address()` implementation,
164    // once https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2214 lands.
165    #[allow(clippy::case_sensitive_file_extension_comparisons)]
166    pub fn is_onion_address(&self) -> bool {
167        let host = self.host_str().unwrap_or_default();
168
169        host.ends_with(".onion")
170    }
171
172    pub fn fragment(&self) -> Option<&str> {
173        self.0.fragment()
174    }
175
176    pub fn set_fragment(&mut self, arg: Option<&str>) {
177        self.0.set_fragment(arg);
178    }
179}
180
181static SHOW_SECRETS: LazyLock<bool> = LazyLock::new(|| {
182    let enable = is_env_var_set(FM_DEBUG_SHOW_SECRETS_ENV);
183
184    if enable {
185        warn!(target: LOG_CORE, "{} enabled. Please don't use in production.", FM_DEBUG_SHOW_SECRETS_ENV);
186    }
187
188    enable
189});
190
191impl Display for SafeUrl {
192    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193        write!(f, "{}://", self.0.scheme())?;
194
195        if !self.0.username().is_empty() {
196            let show_secrets = *SHOW_SECRETS;
197            if show_secrets {
198                write!(f, "{}", self.0.username())?;
199            } else {
200                write!(f, "REDACTEDUSER")?;
201            }
202
203            if self.0.password().is_some() {
204                if show_secrets {
205                    write!(
206                        f,
207                        ":{}",
208                        self.0.password().expect("Just checked it's checked")
209                    )?;
210                } else {
211                    write!(f, ":REDACTEDPASS")?;
212                }
213            }
214
215            write!(f, "@")?;
216        }
217
218        if let Some(host) = self.0.host_str() {
219            write!(f, "{host}")?;
220        }
221
222        if let Some(port) = self.0.port() {
223            write!(f, ":{port}")?;
224        }
225
226        write!(f, "{}", self.0.path())?;
227
228        Ok(())
229    }
230}
231
232impl Debug for SafeUrl {
233    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234        write!(f, "SafeUrl(")?;
235        Display::fmt(self, f)?;
236        write!(f, ")")?;
237        Ok(())
238    }
239}
240
241impl From<Url> for SafeUrl {
242    fn from(u: Url) -> Self {
243        Self(u)
244    }
245}
246
247impl FromStr for SafeUrl {
248    type Err = ParseError;
249
250    #[inline]
251    fn from_str(input: &str) -> Result<Self, ParseError> {
252        Self::parse(input)
253    }
254}
255
/// Creates `path` and writes `contents` into it, refusing to touch an existing
/// file (like [`std::fs::write`], but with `create_new` semantics).
///
/// If the file already exists an [`io::ErrorKind::AlreadyExists`] error is
/// returned and the file is left as it was. On success the data is synced to
/// disk with `sync_all` before returning.
#[cfg(not(target_family = "wasm"))]
pub fn write_new<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut options = fs::OpenOptions::new();
    options.write(true).create_new(true);

    let mut new_file = options.open(path)?;
    new_file.write_all(contents.as_ref())?;
    new_file.sync_all()
}
268
/// Writes `contents` to `path`, creating the file if it does not exist and
/// truncating any previous contents.
///
/// Unlike `write_new`, this does not call `sync_all`, so the data may still be
/// sitting in OS buffers when this returns.
#[cfg(not(target_family = "wasm"))]
pub fn write_overwrite<P: AsRef<Path>, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> {
    let mut options = fs::OpenOptions::new();
    options.create(true).truncate(true).write(true);

    let mut file = options.open(path)?;
    file.write_all(contents.as_ref())
}
278
279#[cfg(not(target_family = "wasm"))]
280pub async fn write_overwrite_async<P: AsRef<Path>, C: AsRef<[u8]>>(
281    path: P,
282    contents: C,
283) -> io::Result<()> {
284    tokio::fs::OpenOptions::new()
285        .write(true)
286        .create(true)
287        .truncate(true)
288        .open(path)
289        .await?
290        .write_all(contents.as_ref())
291        .await
292}
293
294#[cfg(not(target_family = "wasm"))]
295pub async fn write_new_async<P: AsRef<Path>, C: AsRef<[u8]>>(
296    path: P,
297    contents: C,
298) -> io::Result<()> {
299    tokio::fs::OpenOptions::new()
300        .write(true)
301        .create_new(true)
302        .open(path)
303        .await?
304        .write_all(contents.as_ref())
305        .await
306}
307
/// A value paired with a tracing [`Span`], so that later work on the value can
/// be run inside that span (see the `with*` methods).
#[derive(Debug, Clone)]
pub struct Spanned<T> {
    // The wrapped value.
    value: T,
    // The span associated with `value`.
    span: Span,
}
313
314impl<T> Spanned<T> {
315    pub async fn new<F: Future<Output = T>>(span: Span, make: F) -> Self {
316        Self::try_new::<Infallible, _>(span, async { Ok(make.await) })
317            .await
318            .unwrap()
319    }
320
321    pub async fn try_new<E, F: Future<Output = Result<T, E>>>(
322        span: Span,
323        make: F,
324    ) -> Result<Self, E> {
325        let span2 = span.clone();
326        async {
327            Ok(Self {
328                value: make.await?,
329                span: span2,
330            })
331        }
332        .instrument(span)
333        .await
334    }
335
336    pub fn borrow(&self) -> Spanned<&T> {
337        Spanned {
338            value: &self.value,
339            span: self.span.clone(),
340        }
341    }
342
343    pub fn map<U>(self, map: impl Fn(T) -> U) -> Spanned<U> {
344        Spanned {
345            value: map(self.value),
346            span: self.span,
347        }
348    }
349
350    pub fn borrow_mut(&mut self) -> Spanned<&mut T> {
351        Spanned {
352            value: &mut self.value,
353            span: self.span.clone(),
354        }
355    }
356
357    pub fn with_sync<O, F: FnOnce(T) -> O>(self, f: F) -> O {
358        let _g = self.span.enter();
359        f(self.value)
360    }
361
362    pub async fn with<Fut: Future, F: FnOnce(T) -> Fut>(self, f: F) -> Fut::Output {
363        async { f(self.value).await }.instrument(self.span).await
364    }
365
366    pub fn span(&self) -> Span {
367        self.span.clone()
368    }
369
370    pub fn value(&self) -> &T {
371        &self.value
372    }
373
374    pub fn value_mut(&mut self) -> &mut T {
375        &mut self.value
376    }
377
378    pub fn into_value(self) -> T {
379        self.value
380    }
381}
382
/// For CLIs: if the first command-line argument is exactly `version-hash`,
/// prints `version_hash` to stdout and exits the process with status 0.
///
/// Only the first argument is inspected; any further arguments are ignored.
/// If the first argument is missing or different, this returns without
/// doing anything.
pub fn handle_version_hash_command(version_hash: &str) {
    let first_arg = std::env::args().nth(1);
    if first_arg.as_deref() != Some("version-hash") {
        return;
    }

    println!("{version_hash}");
    std::process::exit(0);
}
394
395/// Run the supplied closure `op_fn` until it succeeds. Frequency and number of
396/// retries is determined by the specified strategy.
397///
398/// ```
399/// use std::time::Duration;
400///
401/// use fedimint_core::util::{backoff_util, retry};
402/// # tokio_test::block_on(async {
403/// retry(
404///     "Gateway balance after swap".to_string(),
405///     backoff_util::background_backoff(),
406///     || async {
407///         // Fallible network calls …
408///         Ok(())
409///     },
410/// )
411/// .await
412/// .expect("never fails");
413/// # });
414/// ```
415///
416/// # Returns
417///
418/// - If the closure runs successfully, the result is immediately returned
419/// - If the closure did not run successfully for `max_attempts` times, the
420///   error of the closure is returned
421pub async fn retry<F, Fut, T>(
422    op_name: impl Into<String>,
423    strategy: impl backoff_util::Backoff,
424    op_fn: F,
425) -> Result<T, anyhow::Error>
426where
427    F: Fn() -> Fut,
428    Fut: Future<Output = Result<T, anyhow::Error>>,
429{
430    let mut strategy = strategy;
431    let op_name = op_name.into();
432    let mut attempts: u64 = 0;
433    loop {
434        attempts += 1;
435        match op_fn().await {
436            Ok(result) => return Ok(result),
437            Err(err) => {
438                if let Some(interval) = strategy.next() {
439                    // run closure op_fn again
440                    debug!(
441                        target: LOG_CORE,
442                        err = %err.fmt_compact_anyhow(),
443                        %attempts,
444                        interval = interval.as_secs(),
445                        "{} failed, retrying",
446                        op_name,
447                    );
448                    runtime::sleep(interval).await;
449                } else {
450                    warn!(
451                        target: LOG_CORE,
452                        err = %err.fmt_compact_anyhow(),
453                        %attempts,
454                        "{} failed",
455                        op_name,
456                    );
457                    return Err(err);
458                }
459            }
460        }
461    }
462}
463
/// Computes the median of an already-sorted slice of `u64`s.
///
/// Returns `None` for an empty slice. With an even number of elements the two
/// middle values are combined with [`u64::midpoint`], which rounds down and
/// cannot overflow. The input is not sorted here; callers must pass sorted
/// data for the result to be a median.
pub fn get_median(vals: &[u64]) -> Option<u64> {
    let half = vals.len() / 2;
    // `get` doubles as the emptiness check: an empty slice has no index 0.
    let upper_mid = *vals.get(half)?;

    let median = if vals.len().is_multiple_of(2) {
        u64::midpoint(vals[half - 1], upper_mid)
    } else {
        upper_mid
    };
    Some(median)
}
478
/// Computes the average of the given `u64` slice, rounded down.
///
/// Returns `None` for an empty slice. The sum is accumulated in `u128`, so
/// even slices of very large values cannot overflow.
pub fn get_average(vals: &[u64]) -> Option<u64> {
    if vals.is_empty() {
        return None;
    }

    // A `u64` sum overflows with as few as two large values (panicking in
    // debug builds, wrapping to a wrong result in release builds).
    let sum: u128 = vals.iter().copied().map(u128::from).sum();
    let average = sum / vals.len() as u128;
    // The mean never exceeds the largest element, so it always fits in `u64`.
    Some(u64::try_from(average).expect("average of u64 values fits in u64"))
}
488
489#[cfg(test)]
490mod tests;