fedimint_core/db/
mem_impl.rs1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use anyhow::Result;
6use futures::{StreamExt, stream};
7use imbl::OrdMap;
8use macro_rules_attribute::apply;
9
10use super::{
11 IDatabaseTransactionOps, IDatabaseTransactionOpsCore, IRawDatabase, IRawDatabaseTransaction,
12};
13use crate::async_trait_maybe_send;
14use crate::db::PrefixStream;
15
/// A key write recorded in a [`MemTransaction`]'s operation log so it can be
/// replayed against the shared data on commit.
#[derive(Debug, Default)]
pub struct DatabaseInsertOperation {
    /// Key being written.
    pub key: Vec<u8>,
    /// Value stored under `key`.
    pub value: Vec<u8>,
    /// Value the transaction's view held under `key` just before this write
    /// (`None` if the key was absent). On commit it is compared with the value
    /// actually displaced in the committed data to detect write-write conflicts.
    pub old_value: Option<Vec<u8>>,
}
22
/// A key removal recorded in a [`MemTransaction`]'s operation log so it can be
/// replayed against the shared data on commit.
#[derive(Debug, Default)]
pub struct DatabaseDeleteOperation {
    /// Key being removed.
    pub key: Vec<u8>,
    /// Value the transaction's view held under `key` just before the removal
    /// (`None` if the key was absent). On commit it is compared with the value
    /// actually removed from the committed data to detect write-write conflicts.
    pub old_value: Option<Vec<u8>>,
}
28
/// One entry in a [`MemTransaction`]'s operation log.
#[derive(Debug)]
pub enum DatabaseOperation {
    /// A key was inserted or overwritten.
    Insert(DatabaseInsertOperation),
    /// A key removal was requested (recorded even when the key was absent).
    Delete(DatabaseDeleteOperation),
}
34
/// An in-memory key-value database implementing `IRawDatabase`.
///
/// All committed data lives in a single ordered map behind a
/// `std::sync::RwLock`: transactions copy it under the read lock when they
/// begin and replace it under the write lock when they commit.
#[derive(Default)]
pub struct MemDatabase {
    // Committed key/value pairs, ordered by key.
    data: std::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
}
39
40impl fmt::Debug for MemDatabase {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.write_fmt(format_args!("MemDatabase {{}}",))
43 }
44}
/// A transaction over a [`MemDatabase`].
///
/// Reads and writes go to a private copy of the committed data taken when the
/// transaction begins. Every write is also appended to an operation log that
/// `commit_tx` replays against the shared map. At commit time only
/// write-write conflicts are checked; reads are never validated, so this is
/// not a full isolation implementation.
pub struct MemTransaction<'a> {
    // Writes performed by this transaction, in order.
    operations: Vec<DatabaseOperation>,
    // The transaction's view: the snapshot taken at begin plus its own writes.
    tx_data: OrdMap<Vec<u8>, Vec<u8>>,
    // Database this transaction commits into.
    db: &'a MemDatabase,
    // Copy of `tx_data` as of the last savepoint.
    savepoint: OrdMap<Vec<u8>, Vec<u8>>,
    // Incremented once for every operation appended to `operations`.
    num_pending_operations: usize,
    // Value of `num_pending_operations` when the savepoint was taken.
    num_savepoint_operations: usize,
}
53
54impl fmt::Debug for MemTransaction<'_> {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.write_fmt(format_args!(
57 "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
58 self.db,
59 self.operations.len(),
60 self.tx_data.len(),
61 self.savepoint.len(),
62 self.num_pending_operations,
63 self.num_savepoint_operations,
64 ))
65 }
66}
67
/// A unit error type carrying no data.
///
/// NOTE(review): not referenced by any code visible in this file; presumably
/// used by tests or other modules — confirm before removing.
#[derive(Debug, Eq, PartialEq)]
pub struct DummyError;
70
71impl MemDatabase {
72 pub fn new() -> Self {
73 Self::default()
74 }
75}
76
77#[apply(async_trait_maybe_send!)]
78impl IRawDatabase for MemDatabase {
79 type Transaction<'a> = MemTransaction<'a>;
80 async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
81 let db_copy = self.data.read().expect("Poisoned rwlock").clone();
82 let mut memtx = MemTransaction {
83 operations: Vec::new(),
84 tx_data: db_copy.clone(),
85 db: self,
86 savepoint: db_copy,
87 num_pending_operations: 0,
88 num_savepoint_operations: 0,
89 };
90
91 memtx.set_tx_savepoint().await.expect("can't fail");
92 memtx
93 }
94
95 fn checkpoint(&self, _backup_path: &Path) -> Result<()> {
96 Ok(())
97 }
98}
99
100#[apply(async_trait_maybe_send!)]
103impl IDatabaseTransactionOpsCore for MemTransaction<'_> {
104 async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
105 let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
107 self.operations
108 .push(DatabaseOperation::Insert(DatabaseInsertOperation {
109 key: key.to_vec(),
110 value: value.to_owned(),
111 old_value: old_value.clone(),
112 }));
113 self.num_pending_operations += 1;
114 Ok(old_value)
115 }
116
117 async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
118 Ok(self.tx_data.get(key).cloned())
119 }
120
121 async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
122 let old_value = self.tx_data.remove(&key.to_vec());
124 self.operations
125 .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
126 key: key.to_vec(),
127 old_value: old_value.clone(),
128 }));
129 self.num_pending_operations += 1;
130 Ok(old_value)
131 }
132
133 async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
134 let keys = self
135 .raw_find_by_prefix(key_prefix)
136 .await?
137 .map(|kv| kv.0)
138 .collect::<Vec<_>>()
139 .await;
140 for key in keys {
141 self.raw_remove_entry(key.as_slice()).await?;
142 }
143 Ok(())
144 }
145
146 async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
147 let data = self
148 .tx_data
149 .range((key_prefix.to_vec())..)
150 .take_while(|(key, _)| key.starts_with(key_prefix))
151 .map(|(key, value)| (key.clone(), value.clone()))
152 .collect::<Vec<_>>();
153 Ok(Box::pin(stream::iter(data)))
154 }
155
156 async fn raw_find_by_prefix_sorted_descending(
157 &mut self,
158 key_prefix: &[u8],
159 ) -> Result<PrefixStream<'_>> {
160 let mut data = self
161 .tx_data
162 .range((key_prefix.to_vec())..)
163 .take_while(|(key, _)| key.starts_with(key_prefix))
164 .map(|(key, value)| (key.clone(), value.clone()))
165 .collect::<Vec<_>>();
166 data.sort_by(|a, b| a.cmp(b).reverse());
167
168 Ok(Box::pin(stream::iter(data)))
169 }
170
171 async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
172 let data = self
173 .tx_data
174 .range(Range {
175 start: range.start.to_vec(),
176 end: range.end.to_vec(),
177 })
178 .map(|(key, value)| (key.clone(), value.clone()))
179 .collect::<Vec<_>>();
180 Ok(Box::pin(stream::iter(data)))
181 }
182}
183
184#[apply(async_trait_maybe_send!)]
185impl IDatabaseTransactionOps for MemTransaction<'_> {
186 async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
187 self.tx_data = self.savepoint.clone();
188
189 let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
191 for _i in 0..removed_ops {
192 self.operations.pop();
193 }
194
195 Ok(())
196 }
197
198 async fn set_tx_savepoint(&mut self) -> Result<()> {
199 self.savepoint = self.tx_data.clone();
200 self.num_savepoint_operations = self.num_pending_operations;
201 Ok(())
202 }
203}
204
205#[apply(async_trait_maybe_send!)]
206impl IRawDatabaseTransaction for MemTransaction<'_> {
207 #[allow(clippy::significant_drop_tightening)]
208 async fn commit_tx(self) -> Result<()> {
209 let mut data = self.db.data.write().expect("Poisoned rwlock");
210 let mut data_copy = data.clone();
211 for op in self.operations {
212 match op {
213 DatabaseOperation::Insert(insert_op) => {
214 anyhow::ensure!(
215 data_copy.insert(insert_op.key, insert_op.value) == insert_op.old_value,
216 "write-write conflict"
217 );
218 }
219 DatabaseOperation::Delete(delete_op) => {
220 anyhow::ensure!(
221 data_copy.remove(&delete_op.key) == delete_op.old_value,
222 "write-write conflict"
223 );
224 }
225 }
226 }
227 *data = data_copy;
228 Ok(())
229 }
230}
231
// Unit tests for this module live in a separate `tests` submodule file; it is
// compiled only when building tests.
#[cfg(test)]
mod tests;