fedimint_core/db/
mem_impl.rs

1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use anyhow::Result;
6use futures::{StreamExt, stream};
7use hex::ToHex;
8use imbl::OrdMap;
9use macro_rules_attribute::apply;
10
11use super::{
12    IDatabaseTransactionOps, IDatabaseTransactionOpsCore, IRawDatabase, IRawDatabaseTransaction,
13};
14use crate::async_trait_maybe_send;
15use crate::db::PrefixStream;
16
/// A recorded insert: `key` was set to `value` inside a transaction.
///
/// `old_value` is what the key held in the transaction's own view at the time of the write
/// (`None` if the key was absent). When a `MemTransaction` commits, the committed map must
/// still hold exactly that value for the key, otherwise the commit fails with a
/// write-write conflict.
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key that was written.
    pub key: Vec<u8>,
    /// Value stored under `key`.
    pub value: Vec<u8>,
    /// Previous value of `key` in the transaction's view, if any.
    pub old_value: Option<Vec<u8>>,
}
23
/// A recorded removal of `key` inside a transaction.
///
/// `old_value` is the value the key held in the transaction's own view when it was removed
/// (`None` if it was already absent). On commit, the committed map must yield exactly this
/// value when the key is removed, otherwise the commit fails with a write-write conflict.
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key that was removed.
    pub key: Vec<u8>,
    /// Value the key held in the transaction's view before removal, if any.
    pub old_value: Option<Vec<u8>>,
}
29
30#[derive(Debug)]
31pub enum DatabaseOperation {
32    Insert(DatabaseInsertOperation),
33    Delete(DatabaseDeleteOperation),
34}
35
36#[derive(Default)]
37pub struct MemDatabase {
38    data: tokio::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
39}
40
41impl fmt::Debug for MemDatabase {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.write_fmt(format_args!("MemDatabase {{}}",))
44    }
45}
46pub struct MemTransaction<'a> {
47    operations: Vec<DatabaseOperation>,
48    tx_data: OrdMap<Vec<u8>, Vec<u8>>,
49    db: &'a MemDatabase,
50    savepoint: OrdMap<Vec<u8>, Vec<u8>>,
51    num_pending_operations: usize,
52    num_savepoint_operations: usize,
53}
54
55impl<'a> fmt::Debug for MemTransaction<'a> {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.write_fmt(format_args!(
58            "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
59            self.db,
60            self.operations.len(),
61            self.tx_data.len(),
62            self.savepoint.len(),
63            self.num_pending_operations,
64            self.num_savepoint_operations,
65        ))
66    }
67}
68
/// A field-less error marker; every value is equal to every other.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
71
72impl MemDatabase {
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    pub async fn dump_db(&self) {
78        #[allow(clippy::significant_drop_in_scrutinee)]
79        for (key, value) in self.data.read().await.iter() {
80            println!(
81                "{}: {}",
82                key.encode_hex::<String>(),
83                value.encode_hex::<String>()
84            );
85        }
86    }
87}
88
89#[apply(async_trait_maybe_send!)]
90impl IRawDatabase for MemDatabase {
91    type Transaction<'a> = MemTransaction<'a>;
92    async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
93        let db_copy = self.data.read().await.clone();
94        let mut memtx = MemTransaction {
95            operations: Vec::new(),
96            tx_data: db_copy.clone(),
97            db: self,
98            savepoint: db_copy,
99            num_pending_operations: 0,
100            num_savepoint_operations: 0,
101        };
102
103        memtx.set_tx_savepoint().await.expect("can't fail");
104        memtx
105    }
106
107    fn checkpoint(&self, _backup_path: &Path) -> Result<()> {
108        Ok(())
109    }
110}
111
112// In-memory database transaction should only be used for test code and never
113// for production as it doesn't properly implement MVCC
114#[apply(async_trait_maybe_send!)]
115impl<'a> IDatabaseTransactionOpsCore for MemTransaction<'a> {
116    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
117        // Insert data from copy so we can read our own writes
118        let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
119        self.operations
120            .push(DatabaseOperation::Insert(DatabaseInsertOperation {
121                key: key.to_vec(),
122                value: value.to_owned(),
123                old_value: old_value.clone(),
124            }));
125        self.num_pending_operations += 1;
126        Ok(old_value)
127    }
128
129    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
130        Ok(self.tx_data.get(key).cloned())
131    }
132
133    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
134        // Remove data from copy so we can read our own writes
135        let old_value = self.tx_data.remove(&key.to_vec());
136        self.operations
137            .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
138                key: key.to_vec(),
139                old_value: old_value.clone(),
140            }));
141        self.num_pending_operations += 1;
142        Ok(old_value)
143    }
144
145    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
146        let keys = self
147            .raw_find_by_prefix(key_prefix)
148            .await?
149            .map(|kv| kv.0)
150            .collect::<Vec<_>>()
151            .await;
152        for key in keys {
153            self.raw_remove_entry(key.as_slice()).await?;
154        }
155        Ok(())
156    }
157
158    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
159        let data = self
160            .tx_data
161            .range((key_prefix.to_vec())..)
162            .take_while(|(key, _)| key.starts_with(key_prefix))
163            .map(|(key, value)| (key.clone(), value.clone()))
164            .collect::<Vec<_>>();
165        Ok(Box::pin(stream::iter(data)))
166    }
167
168    async fn raw_find_by_prefix_sorted_descending(
169        &mut self,
170        key_prefix: &[u8],
171    ) -> Result<PrefixStream<'_>> {
172        let mut data = self
173            .tx_data
174            .range((key_prefix.to_vec())..)
175            .take_while(|(key, _)| key.starts_with(key_prefix))
176            .map(|(key, value)| (key.clone(), value.clone()))
177            .collect::<Vec<_>>();
178        data.sort_by(|a, b| a.cmp(b).reverse());
179
180        Ok(Box::pin(stream::iter(data)))
181    }
182
183    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
184        let data = self
185            .tx_data
186            .range(Range {
187                start: range.start.to_vec(),
188                end: range.end.to_vec(),
189            })
190            .map(|(key, value)| (key.clone(), value.clone()))
191            .collect::<Vec<_>>();
192        Ok(Box::pin(stream::iter(data)))
193    }
194}
195
196#[apply(async_trait_maybe_send!)]
197impl<'a> IDatabaseTransactionOps for MemTransaction<'a> {
198    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
199        self.tx_data = self.savepoint.clone();
200
201        // Remove any pending operations beyond the savepoint
202        let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
203        for _i in 0..removed_ops {
204            self.operations.pop();
205        }
206
207        Ok(())
208    }
209
210    async fn set_tx_savepoint(&mut self) -> Result<()> {
211        self.savepoint = self.tx_data.clone();
212        self.num_savepoint_operations = self.num_pending_operations;
213        Ok(())
214    }
215}
216
217#[apply(async_trait_maybe_send!)]
218impl<'a> IRawDatabaseTransaction for MemTransaction<'a> {
219    #[allow(clippy::significant_drop_tightening)]
220    async fn commit_tx(self) -> Result<()> {
221        let mut data = self.db.data.write().await;
222        let mut data_copy = data.clone();
223        for op in self.operations {
224            match op {
225                DatabaseOperation::Insert(insert_op) => {
226                    anyhow::ensure!(
227                        data_copy.insert(insert_op.key, insert_op.value) == insert_op.old_value,
228                        "write-write conflict"
229                    );
230                }
231                DatabaseOperation::Delete(delete_op) => {
232                    anyhow::ensure!(
233                        data_copy.remove(&delete_op.key) == delete_op.old_value,
234                        "write-write conflict"
235                    );
236                }
237            }
238        }
239        *data = data_copy;
240        Ok(())
241    }
242}
243
// Unit tests for this module live out of line in the `tests` submodule.
#[cfg(test)]
mod tests;