fedimint_core/db/mem_impl.rs

1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use futures::{StreamExt, stream};
6use imbl::OrdMap;
7use macro_rules_attribute::apply;
8
9use super::{
10    DatabaseError, DatabaseResult, IDatabaseTransactionOps, IDatabaseTransactionOpsCore,
11    IRawDatabase, IRawDatabaseTransaction,
12};
13use crate::async_trait_maybe_send;
14use crate::db::PrefixStream;
15
/// A recorded key insertion inside a [`MemTransaction`].
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key that was written.
    pub key: Vec<u8>,
    /// Bytes stored under `key`.
    pub value: Vec<u8>,
    /// What the transaction saw under `key` right before this write; compared
    /// against the committed state when the transaction is committed.
    pub old_value: Option<Vec<u8>>,
}

/// A recorded key removal inside a [`MemTransaction`].
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key that was removed.
    pub key: Vec<u8>,
    /// What the transaction saw under `key` right before the removal; compared
    /// against the committed state when the transaction is committed.
    pub old_value: Option<Vec<u8>>,
}

/// One write performed by a [`MemTransaction`], kept in order for replay on commit.
#[derive(Debug)]
pub enum DatabaseOperation {
    /// A key was set to a value.
    Insert(DatabaseInsertOperation),
    /// A key was removed.
    Delete(DatabaseDeleteOperation),
}
34
35#[derive(Default)]
36pub struct MemDatabase {
37    data: std::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
38}
39
40impl fmt::Debug for MemDatabase {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_fmt(format_args!("MemDatabase {{}}",))
43    }
44}
45pub struct MemTransaction<'a> {
46    operations: Vec<DatabaseOperation>,
47    tx_data: OrdMap<Vec<u8>, Vec<u8>>,
48    db: &'a MemDatabase,
49    savepoint: OrdMap<Vec<u8>, Vec<u8>>,
50    num_pending_operations: usize,
51    num_savepoint_operations: usize,
52}
53
54impl fmt::Debug for MemTransaction<'_> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.write_fmt(format_args!(
57            "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
58            self.db,
59            self.operations.len(),
60            self.tx_data.len(),
61            self.savepoint.len(),
62            self.num_pending_operations,
63            self.num_savepoint_operations,
64        ))
65    }
66}
67
/// Unit error type with no payload; all values compare equal.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
70
71impl MemDatabase {
72    pub fn new() -> Self {
73        Self::default()
74    }
75}
76
77#[apply(async_trait_maybe_send!)]
78impl IRawDatabase for MemDatabase {
79    type Transaction<'a> = MemTransaction<'a>;
80    async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
81        let db_copy = self.data.read().expect("Poisoned rwlock").clone();
82        let mut memtx = MemTransaction {
83            operations: Vec::new(),
84            tx_data: db_copy.clone(),
85            db: self,
86            savepoint: db_copy,
87            num_pending_operations: 0,
88            num_savepoint_operations: 0,
89        };
90
91        memtx.set_tx_savepoint().await.expect("can't fail");
92        memtx
93    }
94
95    fn checkpoint(&self, _backup_path: &Path) -> DatabaseResult<()> {
96        Ok(())
97    }
98}
99
100// In-memory database transaction should only be used for test code and never
101// for production as it doesn't properly implement MVCC
102#[apply(async_trait_maybe_send!)]
103impl IDatabaseTransactionOpsCore for MemTransaction<'_> {
104    async fn raw_insert_bytes(
105        &mut self,
106        key: &[u8],
107        value: &[u8],
108    ) -> DatabaseResult<Option<Vec<u8>>> {
109        // Insert data from copy so we can read our own writes
110        let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
111        self.operations
112            .push(DatabaseOperation::Insert(DatabaseInsertOperation {
113                key: key.to_vec(),
114                value: value.to_owned(),
115                old_value: old_value.clone(),
116            }));
117        self.num_pending_operations += 1;
118        Ok(old_value)
119    }
120
121    async fn raw_get_bytes(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
122        Ok(self.tx_data.get(key).cloned())
123    }
124
125    async fn raw_remove_entry(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
126        // Remove data from copy so we can read our own writes
127        let old_value = self.tx_data.remove(&key.to_vec());
128        self.operations
129            .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
130                key: key.to_vec(),
131                old_value: old_value.clone(),
132            }));
133        self.num_pending_operations += 1;
134        Ok(old_value)
135    }
136
137    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<()> {
138        let keys = self
139            .raw_find_by_prefix(key_prefix)
140            .await?
141            .map(|kv| kv.0)
142            .collect::<Vec<_>>()
143            .await;
144        for key in keys {
145            self.raw_remove_entry(key.as_slice()).await?;
146        }
147        Ok(())
148    }
149
150    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<PrefixStream<'_>> {
151        let data = self
152            .tx_data
153            .range((key_prefix.to_vec())..)
154            .take_while(|(key, _)| key.starts_with(key_prefix))
155            .map(|(key, value)| (key.clone(), value.clone()))
156            .collect::<Vec<_>>();
157        Ok(Box::pin(stream::iter(data)))
158    }
159
160    async fn raw_find_by_prefix_sorted_descending(
161        &mut self,
162        key_prefix: &[u8],
163    ) -> DatabaseResult<PrefixStream<'_>> {
164        let mut data = self
165            .tx_data
166            .range((key_prefix.to_vec())..)
167            .take_while(|(key, _)| key.starts_with(key_prefix))
168            .map(|(key, value)| (key.clone(), value.clone()))
169            .collect::<Vec<_>>();
170        data.sort_by(|a, b| a.cmp(b).reverse());
171
172        Ok(Box::pin(stream::iter(data)))
173    }
174
175    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> DatabaseResult<PrefixStream<'_>> {
176        let data = self
177            .tx_data
178            .range(Range {
179                start: range.start.to_vec(),
180                end: range.end.to_vec(),
181            })
182            .map(|(key, value)| (key.clone(), value.clone()))
183            .collect::<Vec<_>>();
184        Ok(Box::pin(stream::iter(data)))
185    }
186}
187
188#[apply(async_trait_maybe_send!)]
189impl IDatabaseTransactionOps for MemTransaction<'_> {
190    async fn rollback_tx_to_savepoint(&mut self) -> DatabaseResult<()> {
191        self.tx_data = self.savepoint.clone();
192
193        // Remove any pending operations beyond the savepoint
194        let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
195        for _i in 0..removed_ops {
196            self.operations.pop();
197        }
198
199        Ok(())
200    }
201
202    async fn set_tx_savepoint(&mut self) -> DatabaseResult<()> {
203        self.savepoint = self.tx_data.clone();
204        self.num_savepoint_operations = self.num_pending_operations;
205        Ok(())
206    }
207}
208
209#[apply(async_trait_maybe_send!)]
210impl IRawDatabaseTransaction for MemTransaction<'_> {
211    #[allow(clippy::significant_drop_tightening)]
212    async fn commit_tx(self) -> DatabaseResult<()> {
213        let mut data = self.db.data.write().expect("Poisoned rwlock");
214        let mut data_copy = data.clone();
215        for op in self.operations {
216            match op {
217                DatabaseOperation::Insert(insert_op) => {
218                    if data_copy.insert(insert_op.key, insert_op.value) != insert_op.old_value {
219                        return Err(DatabaseError::WriteConflict);
220                    }
221                }
222                DatabaseOperation::Delete(delete_op) => {
223                    if data_copy.remove(&delete_op.key) != delete_op.old_value {
224                        return Err(DatabaseError::WriteConflict);
225                    }
226                }
227            }
228        }
229        *data = data_copy;
230        Ok(())
231    }
232}
233
234#[cfg(test)]
235mod tests;