fedimint_core/db/
mem_impl.rs1use std::fmt::{self, Debug};
2use std::ops::Range;
3use std::path::Path;
4
5use anyhow::Result;
6use futures::{StreamExt, stream};
7use hex::ToHex;
8use imbl::OrdMap;
9use macro_rules_attribute::apply;
10
11use super::{
12 IDatabaseTransactionOps, IDatabaseTransactionOpsCore, IRawDatabase, IRawDatabaseTransaction,
13};
14use crate::async_trait_maybe_send;
15use crate::db::PrefixStream;
16
/// A write recorded by a transaction's operation log.
///
/// `key` was set to `value`; `old_value` is the value the transaction saw under `key`
/// just before the write (`None` if the key was absent). On commit the write is
/// replayed and `old_value` is compared against the committed data to detect
/// write-write conflicts.
#[derive(Default, Debug)]
pub struct DatabaseInsertOperation {
    /// Key that was written.
    pub key: Vec<u8>,
    /// New value stored under `key`.
    pub value: Vec<u8>,
    /// Value visible to the transaction before this write, if any.
    pub old_value: Option<Vec<u8>>,
}
23
/// A deletion recorded by a transaction's operation log.
///
/// `old_value` is the value the transaction saw under `key` before removing it
/// (`None` if nothing was there). It is checked against the committed data when
/// the transaction commits.
#[derive(Default, Debug)]
pub struct DatabaseDeleteOperation {
    /// Key that was removed.
    pub key: Vec<u8>,
    /// Value visible to the transaction before the removal, if any.
    pub old_value: Option<Vec<u8>>,
}
29
30#[derive(Debug)]
31pub enum DatabaseOperation {
32 Insert(DatabaseInsertOperation),
33 Delete(DatabaseDeleteOperation),
34}
35
36#[derive(Default)]
37pub struct MemDatabase {
38 data: tokio::sync::RwLock<OrdMap<Vec<u8>, Vec<u8>>>,
39}
40
41impl fmt::Debug for MemDatabase {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.write_fmt(format_args!("MemDatabase {{}}",))
44 }
45}
46pub struct MemTransaction<'a> {
47 operations: Vec<DatabaseOperation>,
48 tx_data: OrdMap<Vec<u8>, Vec<u8>>,
49 db: &'a MemDatabase,
50 savepoint: OrdMap<Vec<u8>, Vec<u8>>,
51 num_pending_operations: usize,
52 num_savepoint_operations: usize,
53}
54
55impl<'a> fmt::Debug for MemTransaction<'a> {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.write_fmt(format_args!(
58 "MemTransaction {{ db={:?}, operations_len={}, tx_data_len={}, savepoint_len={}, num_pending_ops={}, num_savepoint_ops={} }}",
59 self.db,
60 self.operations.len(),
61 self.tx_data.len(),
62 self.savepoint.len(),
63 self.num_pending_operations,
64 self.num_savepoint_operations,
65 ))
66 }
67}
68
/// Placeholder unit error type.
///
/// NOTE(review): not referenced elsewhere in this file — presumably used by tests or
/// other modules; confirm before removing.
#[derive(PartialEq, Eq, Debug)]
pub struct DummyError;
71
72impl MemDatabase {
73 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub async fn dump_db(&self) {
78 #[allow(clippy::significant_drop_in_scrutinee)]
79 for (key, value) in self.data.read().await.iter() {
80 println!(
81 "{}: {}",
82 key.encode_hex::<String>(),
83 value.encode_hex::<String>()
84 );
85 }
86 }
87}
88
89#[apply(async_trait_maybe_send!)]
90impl IRawDatabase for MemDatabase {
91 type Transaction<'a> = MemTransaction<'a>;
92 async fn begin_transaction<'a>(&'a self) -> MemTransaction<'a> {
93 let db_copy = self.data.read().await.clone();
94 let mut memtx = MemTransaction {
95 operations: Vec::new(),
96 tx_data: db_copy.clone(),
97 db: self,
98 savepoint: db_copy,
99 num_pending_operations: 0,
100 num_savepoint_operations: 0,
101 };
102
103 memtx.set_tx_savepoint().await.expect("can't fail");
104 memtx
105 }
106
107 fn checkpoint(&self, _backup_path: &Path) -> Result<()> {
108 Ok(())
109 }
110}
111
112#[apply(async_trait_maybe_send!)]
115impl<'a> IDatabaseTransactionOpsCore for MemTransaction<'a> {
116 async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
117 let old_value = self.tx_data.insert(key.to_vec(), value.to_owned());
119 self.operations
120 .push(DatabaseOperation::Insert(DatabaseInsertOperation {
121 key: key.to_vec(),
122 value: value.to_owned(),
123 old_value: old_value.clone(),
124 }));
125 self.num_pending_operations += 1;
126 Ok(old_value)
127 }
128
129 async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
130 Ok(self.tx_data.get(key).cloned())
131 }
132
133 async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
134 let old_value = self.tx_data.remove(&key.to_vec());
136 self.operations
137 .push(DatabaseOperation::Delete(DatabaseDeleteOperation {
138 key: key.to_vec(),
139 old_value: old_value.clone(),
140 }));
141 self.num_pending_operations += 1;
142 Ok(old_value)
143 }
144
145 async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
146 let keys = self
147 .raw_find_by_prefix(key_prefix)
148 .await?
149 .map(|kv| kv.0)
150 .collect::<Vec<_>>()
151 .await;
152 for key in keys {
153 self.raw_remove_entry(key.as_slice()).await?;
154 }
155 Ok(())
156 }
157
158 async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
159 let data = self
160 .tx_data
161 .range((key_prefix.to_vec())..)
162 .take_while(|(key, _)| key.starts_with(key_prefix))
163 .map(|(key, value)| (key.clone(), value.clone()))
164 .collect::<Vec<_>>();
165 Ok(Box::pin(stream::iter(data)))
166 }
167
168 async fn raw_find_by_prefix_sorted_descending(
169 &mut self,
170 key_prefix: &[u8],
171 ) -> Result<PrefixStream<'_>> {
172 let mut data = self
173 .tx_data
174 .range((key_prefix.to_vec())..)
175 .take_while(|(key, _)| key.starts_with(key_prefix))
176 .map(|(key, value)| (key.clone(), value.clone()))
177 .collect::<Vec<_>>();
178 data.sort_by(|a, b| a.cmp(b).reverse());
179
180 Ok(Box::pin(stream::iter(data)))
181 }
182
183 async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
184 let data = self
185 .tx_data
186 .range(Range {
187 start: range.start.to_vec(),
188 end: range.end.to_vec(),
189 })
190 .map(|(key, value)| (key.clone(), value.clone()))
191 .collect::<Vec<_>>();
192 Ok(Box::pin(stream::iter(data)))
193 }
194}
195
196#[apply(async_trait_maybe_send!)]
197impl<'a> IDatabaseTransactionOps for MemTransaction<'a> {
198 async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
199 self.tx_data = self.savepoint.clone();
200
201 let removed_ops = self.num_pending_operations - self.num_savepoint_operations;
203 for _i in 0..removed_ops {
204 self.operations.pop();
205 }
206
207 Ok(())
208 }
209
210 async fn set_tx_savepoint(&mut self) -> Result<()> {
211 self.savepoint = self.tx_data.clone();
212 self.num_savepoint_operations = self.num_pending_operations;
213 Ok(())
214 }
215}
216
217#[apply(async_trait_maybe_send!)]
218impl<'a> IRawDatabaseTransaction for MemTransaction<'a> {
219 #[allow(clippy::significant_drop_tightening)]
220 async fn commit_tx(self) -> Result<()> {
221 let mut data = self.db.data.write().await;
222 let mut data_copy = data.clone();
223 for op in self.operations {
224 match op {
225 DatabaseOperation::Insert(insert_op) => {
226 anyhow::ensure!(
227 data_copy.insert(insert_op.key, insert_op.value) == insert_op.old_value,
228 "write-write conflict"
229 );
230 }
231 DatabaseOperation::Delete(delete_op) => {
232 anyhow::ensure!(
233 data_copy.remove(&delete_op.key) == delete_op.old_value,
234 "write-write conflict"
235 );
236 }
237 }
238 }
239 *data = data_copy;
240 Ok(())
241 }
242}
243
// Unit tests are declared here and loaded from a separate `tests` module file;
// they are compiled only for test builds.
#[cfg(test)]
mod tests;