fedimint_core/db/mem_impl.rs

1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use futures::{StreamExt, stream};
6use imbl::OrdMap;
7use macro_rules_attribute::apply;
8
9use super::{
10    DatabaseError, DatabaseResult, IDatabaseTransactionOps, IDatabaseTransactionOpsCore,
11    IRawDatabase, IRawDatabaseTransaction,
12};
13use crate::async_trait_maybe_send;
14use crate::db::PrefixStream;
15
/// A recorded write that set `key` to `value`.
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key that was written.
    pub key: Vec<u8>,
    /// Value stored under `key` by this write.
    pub value: Vec<u8>,
    /// Value `key` held immediately before this write (`None` if it was absent).
    pub old_value: Option<Vec<u8>>,
}

/// A recorded removal of `key`.
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key that was removed.
    pub key: Vec<u8>,
    /// Value `key` held immediately before the removal (`None` if it was absent).
    pub old_value: Option<Vec<u8>>,
}

/// A single write recorded by a transaction, kept so it can be replayed and
/// checked for conflicts later.
#[derive(Debug)]
pub enum DatabaseOperation {
    /// A key was inserted or overwritten.
    Insert(DatabaseInsertOperation),
    /// A key was removed (or a removal of an absent key was attempted).
    Delete(DatabaseDeleteOperation),
}
34
35#[derive(Default)]
36pub struct MemDatabase {
37    data: std::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
38}
39
40impl fmt::Debug for MemDatabase {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_fmt(format_args!("MemDatabase {{}}",))
43    }
44}
45pub struct MemTransaction<'a> {
46    operations: Vec<DatabaseOperation>,
47    tx_data: OrdMap<Vec<u8>, Vec<u8>>,
48    db: &'a MemDatabase,
49}
50
51impl fmt::Debug for MemTransaction<'_> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.write_fmt(format_args!(
54            "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={} }}",
55            self.db,
56            self.operations.len(),
57            self.tx_data.len(),
58        ))
59    }
60}
61
/// A unit error type that carries no information.
///
/// NOTE(review): not referenced elsewhere in this file; presumably used by the
/// tests or other crates — confirm before changing.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
64
65impl MemDatabase {
66    pub fn new() -> Self {
67        Self::default()
68    }
69}
70
71#[apply(async_trait_maybe_send!)]
72impl IRawDatabase for MemDatabase {
73    type Transaction<'a> = MemTransaction<'a>;
74    async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
75        let db_copy = self.data.read().expect("Poisoned rwlock").clone();
76        MemTransaction {
77            operations: Vec::new(),
78            tx_data: db_copy,
79            db: self,
80        }
81    }
82
83    fn checkpoint(&self, _backup_path: &Path) -> DatabaseResult<()> {
84        Ok(())
85    }
86}
87
88// In-memory database transaction should only be used for test code and never
89// for production as it doesn't properly implement MVCC
90#[apply(async_trait_maybe_send!)]
91impl IDatabaseTransactionOpsCore for MemTransaction<'_> {
92    async fn raw_insert_bytes(
93        &mut self,
94        key: &[u8],
95        value: &[u8],
96    ) -> DatabaseResult<Option<Vec<u8>>> {
97        // Insert data from copy so we can read our own writes
98        let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
99        self.operations
100            .push(DatabaseOperation::Insert(DatabaseInsertOperation {
101                key: key.to_vec(),
102                value: value.to_owned(),
103                old_value: old_value.clone(),
104            }));
105        Ok(old_value)
106    }
107
108    async fn raw_get_bytes(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
109        Ok(self.tx_data.get(key).cloned())
110    }
111
112    async fn raw_remove_entry(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
113        // Remove data from copy so we can read our own writes
114        let old_value = self.tx_data.remove(&key.to_vec());
115        self.operations
116            .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
117                key: key.to_vec(),
118                old_value: old_value.clone(),
119            }));
120        Ok(old_value)
121    }
122
123    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<()> {
124        let keys = self
125            .raw_find_by_prefix(key_prefix)
126            .await?
127            .map(|kv| kv.0)
128            .collect::<Vec<_>>()
129            .await;
130        for key in keys {
131            self.raw_remove_entry(key.as_slice()).await?;
132        }
133        Ok(())
134    }
135
136    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<PrefixStream<'_>> {
137        let data = self
138            .tx_data
139            .range((key_prefix.to_vec())..)
140            .take_while(|(key, _)| key.starts_with(key_prefix))
141            .map(|(key, value)| (key.clone(), value.clone()))
142            .collect::<Vec<_>>();
143        Ok(Box::pin(stream::iter(data)))
144    }
145
146    async fn raw_find_by_prefix_sorted_descending(
147        &mut self,
148        key_prefix: &[u8],
149    ) -> DatabaseResult<PrefixStream<'_>> {
150        let mut data = self
151            .tx_data
152            .range((key_prefix.to_vec())..)
153            .take_while(|(key, _)| key.starts_with(key_prefix))
154            .map(|(key, value)| (key.clone(), value.clone()))
155            .collect::<Vec<_>>();
156        data.sort_by(|a, b| a.cmp(b).reverse());
157
158        Ok(Box::pin(stream::iter(data)))
159    }
160
161    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> DatabaseResult<PrefixStream<'_>> {
162        let data = self
163            .tx_data
164            .range(Range {
165                start: range.start.to_vec(),
166                end: range.end.to_vec(),
167            })
168            .map(|(key, value)| (key.clone(), value.clone()))
169            .collect::<Vec<_>>();
170        Ok(Box::pin(stream::iter(data)))
171    }
172}
173
// Implemented with no overrides: every method comes from the trait's default
// implementations (presumably built on the `IDatabaseTransactionOpsCore`
// methods — confirm against the trait definition).
impl IDatabaseTransactionOps for MemTransaction<'_> {}
175
176#[apply(async_trait_maybe_send!)]
177impl IRawDatabaseTransaction for MemTransaction<'_> {
178    #[allow(clippy::significant_drop_tightening)]
179    async fn commit_tx(self) -> DatabaseResult<()> {
180        let mut data = self.db.data.write().expect("Poisoned rwlock");
181        let mut data_copy = data.clone();
182        for op in self.operations {
183            match op {
184                DatabaseOperation::Insert(insert_op) => {
185                    if data_copy.insert(insert_op.key, insert_op.value) != insert_op.old_value {
186                        return Err(DatabaseError::WriteConflict);
187                    }
188                }
189                DatabaseOperation::Delete(delete_op) => {
190                    if data_copy.remove(&delete_op.key) != delete_op.old_value {
191                        return Err(DatabaseError::WriteConflict);
192                    }
193                }
194            }
195        }
196        *data = data_copy;
197        Ok(())
198    }
199}
200
// Unit tests for this module live in a separate `tests` module file and are
// compiled only in test builds.
#[cfg(test)]
mod tests;