1use std::collections::BTreeMap;
2use std::net::{TcpListener, UdpSocket};
34use serde::{Deserialize, Serialize};
5use tracing::{debug, trace, warn};
/// The lowest port number to try. Ports below 10k are typically used by normal
/// software, increasing the chance they would get in the way.
const LOW: u16 = 10000;

/// The highest port number to try. Ports above 32k are typically ephemeral,
/// increasing the chance of random conflicts after a port was already tried.
const HIGH: u16 = 32000;

/// `tracing` target for port-allocation log events.
const LOG_PORT_ALLOC: &str = "port-alloc";
1617#[derive(Serialize, Deserialize, Debug, Clone, Default)]
18#[serde(rename_all = "kebab-case")]
19struct RangeData {
20/// Port range size.
21size: u16,
2223/// Unix timestamp when this range expires.
24expires: UnixTimestamp,
25}
2627type UnixTimestamp = u64;
2829fn default_next() -> u16 {
30 LOW
31}
3233#[derive(Serialize, Deserialize, Debug, Clone)]
34#[serde(rename_all = "kebab-case")]
35pub struct RootData {
36/// Next port to try.
37#[serde(default = "default_next")]
38next: u16,
3940/// Map of port ranges. For each range, the key is the first port in the
41 /// range and the range size and expiration time are stored in the value.
42keys: BTreeMap<u16, RangeData>,
43}
4445impl Default for RootData {
46fn default() -> Self {
47Self {
48 next: LOW,
49 keys: Default::default(),
50 }
51 }
52}
5354impl RootData {
55pub fn get_free_port_range(&mut self, range_size: u16) -> u16 {
56trace!(target: LOG_PORT_ALLOC, range_size, "Looking for port");
5758self.reclaim();
5960let mut base_port: u16 = self.next;
61'retry: loop {
62trace!(target: LOG_PORT_ALLOC, base_port, range_size, "Checking a port");
63if base_port > HIGH {
64self.reclaim();
65 base_port = LOW;
66 }
67let range = base_port..base_port + range_size;
68if let Some(next_port) = self.contains(range.clone()) {
69warn!(
70 base_port,
71 range_size,
72"Could not use a port (already reserved). Will try a different range."
73);
74 base_port = next_port;
75continue 'retry;
76 }
7778for port in range.clone() {
79match (
80 TcpListener::bind(("127.0.0.1", port)),
81 UdpSocket::bind(("127.0.0.1", port)),
82 ) {
83 (Err(error), _) | (_, Err(error)) => {
84warn!(
85?error,
86 port, "Could not use a port. Will try a different range"
87);
88 base_port = port + 1;
89continue 'retry;
90 }
91 (Ok(tcp), Ok(udp)) => (tcp, udp),
92 };
93 }
9495self.insert(range);
96debug!(target: LOG_PORT_ALLOC, base_port, range_size, "Allocated port range");
97return base_port;
98 }
99 }
100101/// Remove expired entries from the map.
102fn reclaim(&mut self) {
103let now = Self::now_ts();
104self.keys.retain(|_k, v| now < v.expires);
105 }
106107/// Check if `range` conflicts with anything already reserved
108 ///
109 /// If it does return next address after the range that conflicted.
110fn contains(&self, range: std::ops::Range<u16>) -> Option<u16> {
111self.keys.range(..range.end).next_back().and_then(|(k, v)| {
112let start = *k;
113let end = start + v.size;
114115if start < range.end && range.start < end {
116Some(end)
117 } else {
118None
119}
120 })
121 }
122123fn insert(&mut self, range: std::ops::Range<u16>) {
124const ALLOCATION_TIME_SECS: u64 = 120;
125126// The caller gets some time actually start using the port (`bind`),
127 // to prevent other callers from re-using it. This could typically be
128 // much shorter, as portalloc will not only respect the allocation,
129 // but also try to bind before using a given port range. But for tests
130 // that temporarily release ports (e.g. restarts, failure simulations, etc.),
131 // there's a chance that this can expire and another tests snatches the test,
132 // so better to keep it around the time a longest test can take.
133134assert!(self.contains(range.clone()).is_none());
135self.keys.insert(
136 range.start,
137 RangeData {
138 size: range.len() as u16,
139 expires: Self::now_ts() + ALLOCATION_TIME_SECS,
140 },
141 );
142self.next = range.end;
143 }
144145fn now_ts() -> UnixTimestamp {
146 fedimint_core::time::duration_since_epoch().as_secs()
147 }
148}
#[test]
fn root_data_sanity() {
    let mut r = RootData::default();

    r.insert(2..4);
    r.insert(6..8);
    r.insert(100..108);

    // `insert` moves the search start to just past the most recent range.
    assert_eq!(r.next, 108);

    // Ranges lying entirely in the gaps do not conflict.
    assert_eq!(r.contains(0..2), None);
    assert_eq!(r.contains(4..6), None);
    assert_eq!(r.contains(8..10), None);
    assert_eq!(r.contains(8..100), None);
    assert_eq!(r.contains(108..110), None);

    // Overlaps report the first port after the conflicting reservation.
    assert_eq!(r.contains(0..3), Some(4));
    assert_eq!(r.contains(2..4), Some(4));
    assert_eq!(r.contains(3..4), Some(4));
    assert_eq!(r.contains(3..5), Some(4));
    assert_eq!(r.contains(0..10), Some(8));
    assert_eq!(r.contains(6..10), Some(8));
    assert_eq!(r.contains(7..8), Some(8));
    assert_eq!(r.contains(99..101), Some(108));
    assert_eq!(r.contains(107..200), Some(108));
}