fedimint_portalloc/data/
dto.rs1use std::collections::BTreeMap;
2use std::net::{TcpListener, UdpSocket};
3
4use serde::{Deserialize, Serialize};
5use tracing::{debug, trace, warn};
6
/// Port at which the search for a free range starts, and to which it wraps
/// back after passing `HIGH`.
const LOW: u16 = 10_000;

/// Once a candidate base port exceeds this value the search wraps back to
/// `LOW`. Only the base port is compared, so a reserved range may extend past
/// this value.
const HIGH: u16 = 32_000;

/// `tracing` target used by every log event emitted by this module.
const LOG_PORT_ALLOC: &str = "port-alloc";
16
17#[derive(Serialize, Deserialize, Debug, Clone, Default)]
18#[serde(rename_all = "kebab-case")]
19struct RangeData {
20 size: u16,
22
23 expires: UnixTimestamp,
25}
26
27type UnixTimestamp = u64;
28
29fn default_next() -> u16 {
30 LOW
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone)]
34#[serde(rename_all = "kebab-case")]
35pub struct RootData {
36 #[serde(default = "default_next")]
38 next: u16,
39
40 keys: BTreeMap<u16, RangeData>,
43}
44
45impl Default for RootData {
46 fn default() -> Self {
47 Self {
48 next: LOW,
49 keys: Default::default(),
50 }
51 }
52}
53
54impl RootData {
55 pub fn get_free_port_range(&mut self, range_size: u16) -> u16 {
56 trace!(target: LOG_PORT_ALLOC, range_size, "Looking for port");
57
58 self.reclaim();
59
60 let mut base_port: u16 = self.next;
61 'retry: loop {
62 trace!(target: LOG_PORT_ALLOC, base_port, range_size, "Checking a port");
63 if base_port > HIGH {
64 self.reclaim();
65 base_port = LOW;
66 }
67 let range = base_port..base_port + range_size;
68 if let Some(next_port) = self.contains(range.clone()) {
69 warn!(
70 base_port,
71 range_size,
72 "Could not use a port (already reserved). Will try a different range."
73 );
74 base_port = next_port;
75 continue 'retry;
76 }
77
78 for port in range.clone() {
79 match (
80 TcpListener::bind(("127.0.0.1", port)),
81 UdpSocket::bind(("127.0.0.1", port)),
82 ) {
83 (Err(error), _) | (_, Err(error)) => {
84 warn!(
85 ?error,
86 port, "Could not use a port. Will try a different range"
87 );
88 base_port = port + 1;
89 continue 'retry;
90 }
91 (Ok(tcp), Ok(udp)) => (tcp, udp),
92 };
93 }
94
95 self.insert(range);
96 debug!(target: LOG_PORT_ALLOC, base_port, range_size, "Allocated port range");
97 return base_port;
98 }
99 }
100
101 fn reclaim(&mut self) {
103 let now = Self::now_ts();
104 self.keys.retain(|_k, v| now < v.expires);
105 }
106
107 fn contains(&self, range: std::ops::Range<u16>) -> Option<u16> {
111 self.keys.range(..range.end).next_back().and_then(|(k, v)| {
112 let start = *k;
113 let end = start + v.size;
114
115 if start < range.end && range.start < end {
116 Some(end)
117 } else {
118 None
119 }
120 })
121 }
122
123 fn insert(&mut self, range: std::ops::Range<u16>) {
124 const ALLOCATION_TIME_SECS: u64 = 120;
125
126 assert!(self.contains(range.clone()).is_none());
135 self.keys.insert(
136 range.start,
137 RangeData {
138 size: range.len() as u16,
139 expires: Self::now_ts() + ALLOCATION_TIME_SECS,
140 },
141 );
142 self.next = range.end;
143 }
144
145 fn now_ts() -> UnixTimestamp {
146 fedimint_core::time::duration_since_epoch().as_secs()
147 }
148}
149
/// Exercises `RootData::contains` and `RootData::insert` against a fixed set
/// of disjoint reservations, without touching the network.
#[test]
fn root_data_sanity() {
    let mut r = RootData::default();

    r.insert(2..4);
    r.insert(6..8);
    r.insert(100..108);
    // `insert` moves the search hint to the end of the latest reservation.
    assert_eq!(r.next, 108);

    // Around the 2..4 reservation.
    assert_eq!(r.contains(0..2), None);
    assert_eq!(r.contains(0..3), Some(4));
    assert_eq!(r.contains(2..4), Some(4));
    assert_eq!(r.contains(3..4), Some(4));
    assert_eq!(r.contains(3..5), Some(4));
    assert_eq!(r.contains(4..6), None);

    // Around the 6..8 reservation.
    assert_eq!(r.contains(0..10), Some(8));
    assert_eq!(r.contains(6..10), Some(8));
    assert_eq!(r.contains(7..8), Some(8));
    assert_eq!(r.contains(8..10), None);

    // Around the 100..108 reservation (previously never queried).
    assert_eq!(r.contains(10..100), None);
    assert_eq!(r.contains(99..100), None);
    assert_eq!(r.contains(99..101), Some(108));
    assert_eq!(r.contains(107..200), Some(108));
    assert_eq!(r.contains(108..200), None);
}

/// `insert` must refuse to record a range overlapping an existing one, since
/// `contains` relies on reservations being disjoint.
#[test]
#[should_panic]
fn root_data_rejects_overlapping_insert() {
    let mut r = RootData::default();
    r.insert(2..4);
    r.insert(3..5);
}