fedimint_core/db/mem_impl.rs

1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use anyhow::Result;
6use futures::{StreamExt, stream};
7use imbl::OrdMap;
8use macro_rules_attribute::apply;
9
10use super::{
11    IDatabaseTransactionOps, IDatabaseTransactionOpsCore, IRawDatabase, IRawDatabaseTransaction,
12};
13use crate::async_trait_maybe_send;
14use crate::db::PrefixStream;
15
/// A write recorded in a [`MemTransaction`] operation log: `key` was set to
/// `value`, and `old_value` is what the transaction saw under `key` just
/// before the write.
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key being written.
    pub key: Vec<u8>,
    /// New value stored under `key`.
    pub value: Vec<u8>,
    /// Value visible to the transaction before this write (`None` if absent).
    pub old_value: Option<Vec<u8>>,
}

/// A removal recorded in a [`MemTransaction`] operation log; `old_value` is
/// what the transaction saw under `key` just before the removal.
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key being removed.
    pub key: Vec<u8>,
    /// Value visible to the transaction before the removal (`None` if absent).
    pub old_value: Option<Vec<u8>>,
}

/// One entry of a transaction's operation log, replayed in order at commit.
#[derive(Debug)]
pub enum DatabaseOperation {
    /// A key was written.
    Insert(DatabaseInsertOperation),
    /// A key was removed.
    Delete(DatabaseDeleteOperation),
}
34
35#[derive(Default)]
36pub struct MemDatabase {
37    data: std::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
38}
39
40impl fmt::Debug for MemDatabase {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_fmt(format_args!("MemDatabase {{}}",))
43    }
44}
45pub struct MemTransaction<'a> {
46    operations: Vec<DatabaseOperation>,
47    tx_data: OrdMap<Vec<u8>, Vec<u8>>,
48    db: &'a MemDatabase,
49    savepoint: OrdMap<Vec<u8>, Vec<u8>>,
50    num_pending_operations: usize,
51    num_savepoint_operations: usize,
52}
53
54impl fmt::Debug for MemTransaction<'_> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.write_fmt(format_args!(
57            "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
58            self.db,
59            self.operations.len(),
60            self.tx_data.len(),
61            self.savepoint.len(),
62            self.num_pending_operations,
63            self.num_savepoint_operations,
64        ))
65    }
66}
67
/// Data-less error marker; every value compares equal to every other.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
70
71impl MemDatabase {
72    pub fn new() -> Self {
73        Self::default()
74    }
75}
76
77#[apply(async_trait_maybe_send!)]
78impl IRawDatabase for MemDatabase {
79    type Transaction<'a> = MemTransaction<'a>;
80    async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
81        let db_copy = self.data.read().expect("Poisoned rwlock").clone();
82        let mut memtx = MemTransaction {
83            operations: Vec::new(),
84            tx_data: db_copy.clone(),
85            db: self,
86            savepoint: db_copy,
87            num_pending_operations: 0,
88            num_savepoint_operations: 0,
89        };
90
91        memtx.set_tx_savepoint().await.expect("can't fail");
92        memtx
93    }
94
95    fn checkpoint(&self, _backup_path: &Path) -> Result<()> {
96        Ok(())
97    }
98}
99
100// In-memory database transaction should only be used for test code and never
101// for production as it doesn't properly implement MVCC
102#[apply(async_trait_maybe_send!)]
103impl IDatabaseTransactionOpsCore for MemTransaction<'_> {
104    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
105        // Insert data from copy so we can read our own writes
106        let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
107        self.operations
108            .push(DatabaseOperation::Insert(DatabaseInsertOperation {
109                key: key.to_vec(),
110                value: value.to_owned(),
111                old_value: old_value.clone(),
112            }));
113        self.num_pending_operations += 1;
114        Ok(old_value)
115    }
116
117    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
118        Ok(self.tx_data.get(key).cloned())
119    }
120
121    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
122        // Remove data from copy so we can read our own writes
123        let old_value = self.tx_data.remove(&key.to_vec());
124        self.operations
125            .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
126                key: key.to_vec(),
127                old_value: old_value.clone(),
128            }));
129        self.num_pending_operations += 1;
130        Ok(old_value)
131    }
132
133    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
134        let keys = self
135            .raw_find_by_prefix(key_prefix)
136            .await?
137            .map(|kv| kv.0)
138            .collect::<Vec<_>>()
139            .await;
140        for key in keys {
141            self.raw_remove_entry(key.as_slice()).await?;
142        }
143        Ok(())
144    }
145
146    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
147        let data = self
148            .tx_data
149            .range((key_prefix.to_vec())..)
150            .take_while(|(key, _)| key.starts_with(key_prefix))
151            .map(|(key, value)| (key.clone(), value.clone()))
152            .collect::<Vec<_>>();
153        Ok(Box::pin(stream::iter(data)))
154    }
155
156    async fn raw_find_by_prefix_sorted_descending(
157        &mut self,
158        key_prefix: &[u8],
159    ) -> Result<PrefixStream<'_>> {
160        let mut data = self
161            .tx_data
162            .range((key_prefix.to_vec())..)
163            .take_while(|(key, _)| key.starts_with(key_prefix))
164            .map(|(key, value)| (key.clone(), value.clone()))
165            .collect::<Vec<_>>();
166        data.sort_by(|a, b| a.cmp(b).reverse());
167
168        Ok(Box::pin(stream::iter(data)))
169    }
170
171    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
172        let data = self
173            .tx_data
174            .range(Range {
175                start: range.start.to_vec(),
176                end: range.end.to_vec(),
177            })
178            .map(|(key, value)| (key.clone(), value.clone()))
179            .collect::<Vec<_>>();
180        Ok(Box::pin(stream::iter(data)))
181    }
182}
183
184#[apply(async_trait_maybe_send!)]
185impl IDatabaseTransactionOps for MemTransaction<'_> {
186    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
187        self.tx_data = self.savepoint.clone();
188
189        // Remove any pending operations beyond the savepoint
190        let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
191        for _i in 0..removed_ops {
192            self.operations.pop();
193        }
194
195        Ok(())
196    }
197
198    async fn set_tx_savepoint(&mut self) -> Result<()> {
199        self.savepoint = self.tx_data.clone();
200        self.num_savepoint_operations = self.num_pending_operations;
201        Ok(())
202    }
203}
204
205#[apply(async_trait_maybe_send!)]
206impl IRawDatabaseTransaction for MemTransaction<'_> {
207    #[allow(clippy::significant_drop_tightening)]
208    async fn commit_tx(self) -> Result<()> {
209        let mut data = self.db.data.write().expect("Poisoned rwlock");
210        let mut data_copy = data.clone();
211        for op in self.operations {
212            match op {
213                DatabaseOperation::Insert(insert_op) => {
214                    anyhow::ensure!(
215                        data_copy.insert(insert_op.key, insert_op.value) == insert_op.old_value,
216                        "write-write conflict"
217                    );
218                }
219                DatabaseOperation::Delete(delete_op) => {
220                    anyhow::ensure!(
221                        data_copy.remove(&delete_op.key) == delete_op.old_value,
222                        "write-write conflict"
223                    );
224                }
225            }
226        }
227        *data = data_copy;
228        Ok(())
229    }
230}
231
// Unit tests for this module are in a separate `tests.rs` file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;