// fedimint_portalloc/data/dto.rs

1use std::collections::BTreeMap;
2use std::net::{TcpListener, UdpSocket};
3
4use serde::{Deserialize, Serialize};
5use tracing::{debug, trace, warn};
6
/// The lowest port number to try. Ports below 10k are typically used by
/// regular software, increasing the chance that they would get in the way.
const LOW: u16 = 10000;

/// The highest port number a range may start at. Ports above 32k are
/// typically ephemeral (handed out by the OS), increasing the chance of random
/// conflicts after a port was already tried.
///
/// Only the first port of a range is compared against this bound, so a range
/// starting near it may extend slightly past it.
const HIGH: u16 = 32000;

/// `tracing` target for port-allocation log messages.
const LOG_PORT_ALLOC: &str = "port-alloc";
16
17#[derive(Serialize, Deserialize, Debug, Clone, Default)]
18#[serde(rename_all = "kebab-case")]
19struct RangeData {
20    /// Port range size.
21    size: u16,
22
23    /// Unix timestamp when this range expires.
24    expires: UnixTimestamp,
25}
26
27type UnixTimestamp = u64;
28
29fn default_next() -> u16 {
30    LOW
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone)]
34#[serde(rename_all = "kebab-case")]
35pub struct RootData {
36    /// Next port to try.
37    #[serde(default = "default_next")]
38    next: u16,
39
40    /// Map of port ranges. For each range, the key is the first port in the
41    /// range and the range size and expiration time are stored in the value.
42    keys: BTreeMap<u16, RangeData>,
43}
44
45impl Default for RootData {
46    fn default() -> Self {
47        Self {
48            next: LOW,
49            keys: Default::default(),
50        }
51    }
52}
53
54impl RootData {
55    pub fn get_free_port_range(&mut self, range_size: u16) -> u16 {
56        trace!(target: LOG_PORT_ALLOC, range_size, "Looking for port");
57
58        self.reclaim();
59
60        let mut base_port: u16 = self.next;
61        'retry: loop {
62            trace!(target: LOG_PORT_ALLOC, base_port, range_size, "Checking a port");
63            if base_port > HIGH {
64                self.reclaim();
65                base_port = LOW;
66            }
67            let range = base_port..base_port + range_size;
68            if let Some(next_port) = self.contains(range.clone()) {
69                warn!(
70                    base_port,
71                    range_size,
72                    "Could not use a port (already reserved). Will try a different range."
73                );
74                base_port = next_port;
75                continue 'retry;
76            }
77
78            for port in range.clone() {
79                match (
80                    TcpListener::bind(("127.0.0.1", port)),
81                    UdpSocket::bind(("127.0.0.1", port)),
82                ) {
83                    (Err(error), _) | (_, Err(error)) => {
84                        warn!(
85                            ?error,
86                            port, "Could not use a port. Will try a different range"
87                        );
88                        base_port = port + 1;
89                        continue 'retry;
90                    }
91                    (Ok(tcp), Ok(udp)) => (tcp, udp),
92                };
93            }
94
95            self.insert(range);
96            debug!(target: LOG_PORT_ALLOC, base_port, range_size, "Allocated port range");
97            return base_port;
98        }
99    }
100
101    /// Remove expired entries from the map.
102    fn reclaim(&mut self) {
103        let now = Self::now_ts();
104        self.keys.retain(|_k, v| now < v.expires);
105    }
106
107    /// Check if `range` conflicts with anything already reserved
108    ///
109    /// If it does return next address after the range that conflicted.
110    fn contains(&self, range: std::ops::Range<u16>) -> Option<u16> {
111        self.keys.range(..range.end).next_back().and_then(|(k, v)| {
112            let start = *k;
113            let end = start + v.size;
114
115            if start < range.end && range.start < end {
116                Some(end)
117            } else {
118                None
119            }
120        })
121    }
122
123    fn insert(&mut self, range: std::ops::Range<u16>) {
124        const ALLOCATION_TIME_SECS: u64 = 120;
125
126        // The caller gets some time actually start using the port (`bind`),
127        // to prevent other callers from re-using it. This could typically be
128        // much shorter, as portalloc will not only respect the allocation,
129        // but also try to bind before using a given port range. But for tests
130        // that temporarily release ports (e.g. restarts, failure simulations, etc.),
131        // there's a chance that this can expire and another tests snatches the test,
132        // so better to keep it around the time a longest test can take.
133
134        assert!(self.contains(range.clone()).is_none());
135        self.keys.insert(
136            range.start,
137            RangeData {
138                size: range.len() as u16,
139                expires: Self::now_ts() + ALLOCATION_TIME_SECS,
140            },
141        );
142        self.next = range.end;
143    }
144
145    fn now_ts() -> UnixTimestamp {
146        fedimint_core::time::duration_since_epoch().as_secs()
147    }
148}
149
#[test]
fn root_data_sanity() {
    let mut data = RootData::default();
    for reserved in [2..4, 6..8, 100..108] {
        data.insert(reserved);
    }

    // (queried range, expected first port after the conflicting reservation)
    let cases = [
        (0..2, None),
        (0..3, Some(4)),
        (2..4, Some(4)),
        (3..4, Some(4)),
        (3..5, Some(4)),
        (4..6, None),
        (0..10, Some(8)),
        (6..10, Some(8)),
        (7..8, Some(8)),
        (8..10, None),
    ];
    for (query, expected) in cases {
        assert_eq!(data.contains(query.clone()), expected, "query {:?}", query);
    }
}
167}