fedimint_portalloc/data/
dto.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
use std::collections::BTreeMap;
use std::net::TcpListener;

use serde::{Deserialize, Serialize};
use tracing::{debug, trace, warn};

/// The lowest port number to try. Ports below 10k are commonly used by regular
/// software, which increases the chance of conflicting with it.
const LOW: u16 = 10_000;

/// The highest port number to try. Ports above 32k are typically ephemeral,
/// which increases the chance of random conflicts after a port was already
/// tried.
const HIGH: u16 = 32_000;

/// Tracing target for port-allocation log events.
const LOG_PORT_ALLOC: &str = "port-alloc";

/// One port reservation: a run of `size` consecutive ports whose first port is
/// the key this value is stored under in `RootData::keys`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "kebab-case")]
struct RangeData {
    /// Number of consecutive ports in the range.
    size: u16,

    /// Unix timestamp (in seconds) when this range expires; once the current
    /// time reaches it, the reservation is dropped by `RootData::reclaim`.
    expires: UnixTimestamp,
}

/// Seconds since the Unix epoch.
type UnixTimestamp = u64;

/// Serde default for `RootData::next`, so state that lacks the field starts
/// scanning at `LOW`.
fn default_next() -> u16 {
    LOW
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct RootData {
    /// Next port to try.
    #[serde(default = "default_next")]
    next: u16,

    /// Map of port ranges. For each range, the key is the first port in the
    /// range and the range size and expiration time are stored in the value.
    keys: BTreeMap<u16, RangeData>,
}

impl Default for RootData {
    fn default() -> Self {
        Self {
            next: LOW,
            keys: Default::default(),
        }
    }
}

impl RootData {
    /// Reserve `range_size` consecutive ports and return the first one.
    ///
    /// Expired reservations are dropped first. Candidate ranges are then
    /// scanned starting at `self.next`. Whenever a candidate would fall
    /// outside `LOW..=HIGH`, expired reservations are reclaimed again and the
    /// scan restarts at `LOW`. A candidate is accepted only if it overlaps no
    /// unexpired reservation and every port in it can be bound on `127.0.0.1`
    /// right now. The accepted range is recorded (see `insert`) and
    /// `self.next` is moved past it.
    ///
    /// If no candidate is acceptable, this keeps retrying (without sleeping)
    /// until one becomes acceptable; it never gives up.
    ///
    /// # Panics
    ///
    /// Panics if `range_size` is zero or larger than the `LOW..=HIGH` window,
    /// since such a range could never be allocated.
    pub fn get_free_port_range(&mut self, range_size: u16) -> u16 {
        // Largest range that fits entirely inside `LOW..=HIGH`.
        const MAX_RANGE_SIZE: u16 = HIGH - LOW + 1;
        assert!(
            (1..=MAX_RANGE_SIZE).contains(&range_size),
            "port range size must be between 1 and {}, got {}",
            MAX_RANGE_SIZE,
            range_size
        );

        trace!(target: LOG_PORT_ALLOC, range_size, "Looking for port");

        self.reclaim();

        let mut base_port: u16 = self.next;
        'retry: loop {
            trace!(target: LOG_PORT_ALLOC, base_port, range_size, "Checking a port");

            // The whole range, not just its first port, must stay within
            // `LOW..=HIGH`; `checked_add` also rules out `u16` overflow.
            // After wrapping to `LOW` this always succeeds thanks to the
            // size assertion above, so the wrap cannot loop forever.
            let end_port = match base_port.checked_add(range_size) {
                Some(end) if LOW <= base_port && end - 1 <= HIGH => end,
                _ => {
                    self.reclaim();
                    base_port = LOW;
                    continue 'retry;
                }
            };
            let range = base_port..end_port;

            if let Some(next_port) = self.contains(range.clone()) {
                warn!(
                    target: LOG_PORT_ALLOC,
                    base_port,
                    range_size,
                    "Could not use a port (already reserved). Will try a different range."
                );
                // `next_port` > `base_port`, so the search always advances.
                base_port = next_port;
                continue 'retry;
            }

            for port in range.clone() {
                // Probe only: a successful listener is dropped immediately so
                // the caller can bind the port itself.
                if let Err(error) = TcpListener::bind(("127.0.0.1", port)) {
                    warn!(
                        target: LOG_PORT_ALLOC,
                        ?error,
                        port, "Could not use a port. Will try a different range"
                    );
                    base_port = port + 1;
                    continue 'retry;
                }
            }

            self.insert(range);
            debug!(target: LOG_PORT_ALLOC, base_port, range_size, "Allocated port range");
            return base_port;
        }
    }

    /// Remove every reservation whose expiration time has been reached.
    fn reclaim(&mut self) {
        let now = Self::now_ts();
        self.keys.retain(|_start, data| now < data.expires);
    }

    /// Check whether `range` overlaps any reservation.
    ///
    /// Returns the first port after the conflicting reservation, or `None` if
    /// `range` is free. Any same-sized candidate starting between
    /// `range.start` and that port would conflict too, so callers can resume
    /// the search there.
    ///
    /// Only the reservation with the greatest start below `range.end` is
    /// inspected. That suffices because reservations never overlap each other
    /// (`insert` asserts this), so every earlier one ends no later than that
    /// reservation starts.
    fn contains(&self, range: std::ops::Range<u16>) -> Option<u16> {
        // `range(..range.end)` already guarantees `start < range.end`.
        let (&start, data) = self.keys.range(..range.end).next_back()?;
        // Saturate so a corrupted `size` (e.g. from stored state) can't overflow.
        let end = start.saturating_add(data.size);
        (range.start < end).then_some(end)
    }

    /// Record `range` as reserved and move `self.next` to the port right after
    /// it.
    ///
    /// # Panics
    ///
    /// Panics if `range` is empty or overlaps an existing reservation.
    fn insert(&mut self, range: std::ops::Range<u16>) {
        // How long a reservation is held, in seconds.
        //
        // The caller needs some time to actually start using the port (`bind`)
        // while other callers are kept from re-using it. This could typically
        // be much shorter, since portalloc not only respects the reservation
        // but also tries to bind before handing out a port range. But tests
        // that temporarily release ports (e.g. restarts, failure simulations,
        // etc.) risk the reservation expiring and another test snatching the
        // ports, so it is kept around for as long as the longest test can take.
        const ALLOCATION_TIME_SECS: u64 = 120;

        // An empty range would be stored under the same key as a reservation
        // starting at `range.start` and silently replace it; `contains` may not
        // report a conflict for an empty range.
        assert!(!range.is_empty(), "cannot reserve an empty port range");
        assert!(self.contains(range.clone()).is_none());
        self.keys.insert(
            range.start,
            RangeData {
                size: range.end - range.start,
                expires: Self::now_ts() + ALLOCATION_TIME_SECS,
            },
        );
        self.next = range.end;
    }

    /// Current time in whole seconds, from
    /// `fedimint_core::time::duration_since_epoch`.
    fn now_ts() -> UnixTimestamp {
        fedimint_core::time::duration_since_epoch().as_secs()
    }
}

/// Exercises `RootData::contains` against three reservations: `2..4`, `6..8`
/// and `100..108`.
#[test]
fn root_data_sanity() {
    let mut data = RootData::default();
    for reserved in [2..4, 6..8, 100..108] {
        data.insert(reserved);
    }

    // (queried range, expected first port after the conflicting reservation)
    let cases: [(std::ops::Range<u16>, Option<u16>); 10] = [
        (0..2, None),
        (0..3, Some(4)),
        (2..4, Some(4)),
        (3..4, Some(4)),
        (3..5, Some(4)),
        (4..6, None),
        (0..10, Some(8)),
        (6..10, Some(8)),
        (7..8, Some(8)),
        (8..10, None),
    ];
    for (query, expected) in cases {
        assert_eq!(data.contains(query), expected);
    }
}