fedimint_portalloc/data/
dto.rs

1use std::collections::BTreeMap;
2use std::net::{TcpListener, UdpSocket};
3
4use fedimint_core::util::FmtCompact as _;
5use serde::{Deserialize, Serialize};
6use tracing::{debug, trace, warn};
7
/// Lowest port number the allocator will try.
///
/// Ports below 10k are typically used by regular software, which increases
/// the chance that they would get in the way.
const LOW: u16 = 10_000;

/// Highest port number the allocator will start a range at.
///
/// Ports above 32k are typically ephemeral, which increases the chance of a
/// random conflict after a port was already tried.
const HIGH: u16 = 32_000;

/// Tracing target used for the allocator's log events.
const LOG_PORT_ALLOC: &str = "port-alloc";
17
/// Reservation details stored for a single port range.
///
/// The first port of the range is not stored here; it is the key under which
/// this value is kept in `RootData::keys`. Field names are serialized in
/// kebab-case.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "kebab-case")]
struct RangeData {
    /// Port range size.
    size: u16,

    /// Unix timestamp (in seconds) when this range expires.
    ///
    /// `RootData::reclaim` drops the entry once the current time reaches
    /// this value.
    expires: UnixTimestamp,
}

/// Seconds since the Unix epoch, as produced by `RootData::now_ts`.
type UnixTimestamp = u64;
29
/// Serde default for `RootData::next`: when the field is missing from the
/// serialized data, the search starts at `LOW`.
fn default_next() -> u16 {
    LOW
}
33
/// Serializable allocator state: where to continue searching and which port
/// ranges are currently reserved.
///
/// Field names are serialized in kebab-case.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct RootData {
    /// Next port to try.
    ///
    /// Falls back to `default_next()` (i.e. `LOW`) when absent from the
    /// serialized data.
    #[serde(default = "default_next")]
    next: u16,

    /// Map of port ranges. For each range, the key is the first port in the
    /// range and the range size and expiration time are stored in the value.
    ///
    /// Unlike `next`, this field has no serde default.
    keys: BTreeMap<u16, RangeData>,
}
45
46impl Default for RootData {
47    fn default() -> Self {
48        Self {
49            next: LOW,
50            keys: Default::default(),
51        }
52    }
53}
54
55impl RootData {
56    pub fn get_free_port_range(&mut self, range_size: u16) -> u16 {
57        trace!(target: LOG_PORT_ALLOC, range_size, "Looking for port");
58
59        self.reclaim();
60
61        let mut base_port: u16 = self.next;
62        'retry: loop {
63            trace!(target: LOG_PORT_ALLOC, base_port, range_size, "Checking a port");
64            if base_port > HIGH {
65                self.reclaim();
66                base_port = LOW;
67            }
68            let range = base_port..base_port + range_size;
69            if let Some(next_port) = self.contains(range.clone()) {
70                warn!(
71                    base_port,
72                    range_size,
73                    "Could not use a port (already reserved). Will try a different range."
74                );
75                base_port = next_port;
76                continue 'retry;
77            }
78
79            for port in range.clone() {
80                match (
81                    TcpListener::bind(("127.0.0.1", port)),
82                    UdpSocket::bind(("127.0.0.1", port)),
83                ) {
84                    (Err(err), _) | (_, Err(err)) => {
85                        warn!(
86                            err = %err.fmt_compact(),
87                            port, "Could not use a port. Will try a different range"
88                        );
89                        base_port = port + 1;
90                        continue 'retry;
91                    }
92                    (Ok(tcp), Ok(udp)) => (tcp, udp),
93                };
94            }
95
96            self.insert(range);
97            debug!(target: LOG_PORT_ALLOC, base_port, range_size, "Allocated port range");
98            return base_port;
99        }
100    }
101
102    /// Remove expired entries from the map.
103    fn reclaim(&mut self) {
104        let now = Self::now_ts();
105        self.keys.retain(|_k, v| now < v.expires);
106    }
107
108    /// Check if `range` conflicts with anything already reserved
109    ///
110    /// If it does return next address after the range that conflicted.
111    fn contains(&self, range: std::ops::Range<u16>) -> Option<u16> {
112        self.keys.range(..range.end).next_back().and_then(|(k, v)| {
113            let start = *k;
114            let end = start + v.size;
115
116            if start < range.end && range.start < end {
117                Some(end)
118            } else {
119                None
120            }
121        })
122    }
123
124    fn insert(&mut self, range: std::ops::Range<u16>) {
125        const ALLOCATION_TIME_SECS: u64 = 120;
126
127        // The caller gets some time actually start using the port (`bind`),
128        // to prevent other callers from re-using it. This could typically be
129        // much shorter, as portalloc will not only respect the allocation,
130        // but also try to bind before using a given port range. But for tests
131        // that temporarily release ports (e.g. restarts, failure simulations, etc.),
132        // there's a chance that this can expire and another tests snatches the test,
133        // so better to keep it around the time a longest test can take.
134
135        assert!(self.contains(range.clone()).is_none());
136        self.keys.insert(
137            range.start,
138            RangeData {
139                size: range.len() as u16,
140                expires: Self::now_ts() + ALLOCATION_TIME_SECS,
141            },
142        );
143        self.next = range.end;
144    }
145
146    fn now_ts() -> UnixTimestamp {
147        fedimint_core::time::duration_since_epoch().as_secs()
148    }
149}
150
/// Reserves a few ranges by hand and checks the overlap answers of
/// `RootData::contains` for queries before, inside, across and after them.
#[test]
fn root_data_sanity() {
    let mut root = RootData::default();

    let reserved: Vec<std::ops::Range<u16>> = vec![2..4, 6..8, 100..108];
    for range in reserved {
        root.insert(range);
    }

    // (query, expected answer) pairs, checked in this order.
    let cases: Vec<(std::ops::Range<u16>, Option<u16>)> = vec![
        (0..2, None),
        (0..3, Some(4)),
        (2..4, Some(4)),
        (3..4, Some(4)),
        (3..5, Some(4)),
        (4..6, None),
        (0..10, Some(8)),
        (6..10, Some(8)),
        (7..8, Some(8)),
        (8..10, None),
    ];
    for (query, expected) in cases {
        assert_eq!(
            root.contains(query.clone()),
            expected,
            "unexpected result for query {:?}",
            query
        );
    }
}