fedimint_core/db/
mem_impl.rs1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use futures::{StreamExt, stream};
6use imbl::OrdMap;
7use macro_rules_attribute::apply;
8
9use super::{
10 DatabaseError, DatabaseResult, IDatabaseTransactionOps, IDatabaseTransactionOpsCore,
11 IRawDatabase, IRawDatabaseTransaction,
12};
13use crate::async_trait_maybe_send;
14use crate::db::PrefixStream;
15
/// A recorded insert performed inside a transaction.
///
/// `old_value` is whatever the transaction-local view held under `key` just
/// before the write (`None` if the key was absent); commit replays the insert
/// and compares against it to detect write conflicts.
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key that was written.
    pub key: Vec<u8>,
    /// Value stored under `key`.
    pub value: Vec<u8>,
    /// Value previously stored under `key`, if any.
    pub old_value: Option<Vec<u8>>,
}
22
/// A recorded removal performed inside a transaction.
///
/// `old_value` is whatever the transaction-local view held under `key` just
/// before the removal (`None` if the key was absent); commit replays the
/// removal and compares against it to detect write conflicts.
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key that was removed.
    pub key: Vec<u8>,
    /// Value previously stored under `key`, if any.
    pub old_value: Option<Vec<u8>>,
}
28
29#[derive(Debug)]
30pub enum DatabaseOperation {
31 Insert(DatabaseInsertOperation),
32 Delete(DatabaseDeleteOperation),
33}
34
35#[derive(Default)]
36pub struct MemDatabase {
37 data: std::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
38}
39
40impl fmt::Debug for MemDatabase {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.write_fmt(format_args!("MemDatabase {{}}",))
43 }
44}
45pub struct MemTransaction<'a> {
46 operations: Vec<DatabaseOperation>,
47 tx_data: OrdMap<Vec<u8>, Vec<u8>>,
48 db: &'a MemDatabase,
49 savepoint: OrdMap<Vec<u8>, Vec<u8>>,
50 num_pending_operations: usize,
51 num_savepoint_operations: usize,
52}
53
54impl fmt::Debug for MemTransaction<'_> {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.write_fmt(format_args!(
57 "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
58 self.db,
59 self.operations.len(),
60 self.tx_data.len(),
61 self.savepoint.len(),
62 self.num_pending_operations,
63 self.num_savepoint_operations,
64 ))
65 }
66}
67
/// A field-less placeholder error type.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
70
71impl MemDatabase {
72 pub fn new() -> Self {
73 Self::default()
74 }
75}
76
77#[apply(async_trait_maybe_send!)]
78impl IRawDatabase for MemDatabase {
79 type Transaction<'a> = MemTransaction<'a>;
80 async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
81 let db_copy = self.data.read().expect("Poisoned rwlock").clone();
82 let mut memtx = MemTransaction {
83 operations: Vec::new(),
84 tx_data: db_copy.clone(),
85 db: self,
86 savepoint: db_copy,
87 num_pending_operations: 0,
88 num_savepoint_operations: 0,
89 };
90
91 memtx.set_tx_savepoint().await.expect("can't fail");
92 memtx
93 }
94
95 fn checkpoint(&self, _backup_path: &Path) -> DatabaseResult<()> {
96 Ok(())
97 }
98}
99
100#[apply(async_trait_maybe_send!)]
103impl IDatabaseTransactionOpsCore for MemTransaction<'_> {
104 async fn raw_insert_bytes(
105 &mut self,
106 key: &[u8],
107 value: &[u8],
108 ) -> DatabaseResult<Option<Vec<u8>>> {
109 let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
111 self.operations
112 .push(DatabaseOperation::Insert(DatabaseInsertOperation {
113 key: key.to_vec(),
114 value: value.to_owned(),
115 old_value: old_value.clone(),
116 }));
117 self.num_pending_operations += 1;
118 Ok(old_value)
119 }
120
121 async fn raw_get_bytes(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
122 Ok(self.tx_data.get(key).cloned())
123 }
124
125 async fn raw_remove_entry(&mut self, key: &[u8]) -> DatabaseResult<Option<Vec<u8>>> {
126 let old_value = self.tx_data.remove(&key.to_vec());
128 self.operations
129 .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
130 key: key.to_vec(),
131 old_value: old_value.clone(),
132 }));
133 self.num_pending_operations += 1;
134 Ok(old_value)
135 }
136
137 async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<()> {
138 let keys = self
139 .raw_find_by_prefix(key_prefix)
140 .await?
141 .map(|kv| kv.0)
142 .collect::<Vec<_>>()
143 .await;
144 for key in keys {
145 self.raw_remove_entry(key.as_slice()).await?;
146 }
147 Ok(())
148 }
149
150 async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> DatabaseResult<PrefixStream<'_>> {
151 let data = self
152 .tx_data
153 .range((key_prefix.to_vec())..)
154 .take_while(|(key, _)| key.starts_with(key_prefix))
155 .map(|(key, value)| (key.clone(), value.clone()))
156 .collect::<Vec<_>>();
157 Ok(Box::pin(stream::iter(data)))
158 }
159
160 async fn raw_find_by_prefix_sorted_descending(
161 &mut self,
162 key_prefix: &[u8],
163 ) -> DatabaseResult<PrefixStream<'_>> {
164 let mut data = self
165 .tx_data
166 .range((key_prefix.to_vec())..)
167 .take_while(|(key, _)| key.starts_with(key_prefix))
168 .map(|(key, value)| (key.clone(), value.clone()))
169 .collect::<Vec<_>>();
170 data.sort_by(|a, b| a.cmp(b).reverse());
171
172 Ok(Box::pin(stream::iter(data)))
173 }
174
175 async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> DatabaseResult<PrefixStream<'_>> {
176 let data = self
177 .tx_data
178 .range(Range {
179 start: range.start.to_vec(),
180 end: range.end.to_vec(),
181 })
182 .map(|(key, value)| (key.clone(), value.clone()))
183 .collect::<Vec<_>>();
184 Ok(Box::pin(stream::iter(data)))
185 }
186}
187
188#[apply(async_trait_maybe_send!)]
189impl IDatabaseTransactionOps for MemTransaction<'_> {
190 async fn rollback_tx_to_savepoint(&mut self) -> DatabaseResult<()> {
191 self.tx_data = self.savepoint.clone();
192
193 let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
195 for _i in 0..removed_ops {
196 self.operations.pop();
197 }
198
199 Ok(())
200 }
201
202 async fn set_tx_savepoint(&mut self) -> DatabaseResult<()> {
203 self.savepoint = self.tx_data.clone();
204 self.num_savepoint_operations = self.num_pending_operations;
205 Ok(())
206 }
207}
208
209#[apply(async_trait_maybe_send!)]
210impl IRawDatabaseTransaction for MemTransaction<'_> {
211 #[allow(clippy::significant_drop_tightening)]
212 async fn commit_tx(self) -> DatabaseResult<()> {
213 let mut data = self.db.data.write().expect("Poisoned rwlock");
214 let mut data_copy = data.clone();
215 for op in self.operations {
216 match op {
217 DatabaseOperation::Insert(insert_op) => {
218 if data_copy.insert(insert_op.key, insert_op.value) != insert_op.old_value {
219 return Err(DatabaseError::WriteConflict);
220 }
221 }
222 DatabaseOperation::Delete(delete_op) => {
223 if data_copy.remove(&delete_op.key) != delete_op.old_value {
224 return Err(DatabaseError::WriteConflict);
225 }
226 }
227 }
228 }
229 *data = data_copy;
230 Ok(())
231 }
232}
233
// Unit tests are compiled only for test builds and live in a separate
// `tests` module file.
#[cfg(test)]
mod tests;