use std::any;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt::{self, Debug};
use std::marker::{self, PhantomData};
use std::ops::{self, DerefMut, Range};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, Context, Result};
use fedimint_core::util::BoxFuture;
use fedimint_logging::LOG_DB;
use futures::{Stream, StreamExt};
use macro_rules_attribute::apply;
use rand::Rng;
use serde::Serialize;
use strum_macros::EnumIter;
use thiserror::Error;
use tracing::{debug, error, info, instrument, trace, warn};
use crate::core::ModuleInstanceId;
use crate::encoding::{Decodable, Encodable};
use crate::fmt_utils::AbbreviateHexBytes;
use crate::task::{MaybeSend, MaybeSync};
use crate::{async_trait_maybe_send, maybe_add_send, timing};
pub mod mem_impl;
pub mod notifications;
pub use test_utils::*;
use self::notifications::{Notifications, NotifyQueue};
use crate::module::registry::{ModuleDecoderRegistry, ModuleRegistry};
/// First byte of every module-scoped key prefix.
///
/// `module_instance_id_to_byte_prefix` builds a module's prefix as this byte
/// followed by the consensus encoding of the module instance id.
pub const MODULE_GLOBAL_PREFIX: u8 = 0xff;
/// A value that can be serialized into the raw bytes used as a database key
/// or key prefix.
///
/// The typed helpers in this module pass the result of `to_bytes` directly to
/// the raw byte-level lookup/scan operations.
pub trait DatabaseKeyPrefix: Debug {
    /// Serializes `self` into raw key (or key-prefix) bytes.
    fn to_bytes(&self) -> Vec<u8>;
}
/// A typed database record: ties a key type to the value type stored under it.
pub trait DatabaseRecord: DatabaseKeyPrefix {
    /// Prefix byte associated with this record type.
    // NOTE(review): presumably the first byte of every key of this record
    // type; it is not read anywhere in this file — confirm with the derive/
    // macro that implements this trait.
    const DB_PREFIX: u8;
    /// Defaults to `false`. Not read anywhere in the code shown here.
    const NOTIFY_ON_MODIFY: bool = false;
    /// Decoded key type returned by typed scans.
    type Key: DatabaseKey + Debug;
    /// Decoded value type stored under this record's keys.
    type Value: DatabaseValue + Debug;
}
/// A key prefix usable for prefix scans.
///
/// `Record` is the record type that prefix scans (`find_by_prefix`,
/// `find_by_prefix_sorted_descending`) decode each matching entry into.
pub trait DatabaseLookup: DatabaseKeyPrefix {
    type Record: DatabaseRecord;
}
/// Every encodable record can be used as a lookup for itself, so a full key
/// also works as a (trivially exact) prefix.
impl<Record> DatabaseLookup for Record
where
    Record: DatabaseRecord + Debug + Decodable + Encodable,
{
    type Record = Record;
}
/// A key type that can be decoded back from raw key bytes.
pub trait DatabaseKey: Sized {
    /// Defaults to `false`. Not read anywhere in the code shown here.
    const NOTIFY_ON_MODIFY: bool = false;
    /// Decodes a key from its raw bytes using the module decoder registry.
    fn from_bytes(data: &[u8], modules: &ModuleDecoderRegistry) -> Result<Self, DecodingError>;
}
/// Marker trait required by `Database::wait_key_check` and
/// `Database::wait_key_exists` for keys that may be waited on.
pub trait DatabaseKeyWithNotify {}
/// A value type that can be stored in the database as raw bytes.
pub trait DatabaseValue: Sized + Debug {
    /// Decodes a value from its raw bytes using the module decoder registry.
    fn from_bytes(data: &[u8], modules: &ModuleDecoderRegistry) -> Result<Self, DecodingError>;
    /// Serializes the value into raw bytes for storage.
    fn to_bytes(&self) -> Vec<u8>;
}
/// Boxed, pinned stream of raw `(key_bytes, value_bytes)` pairs returned by the
/// byte-level scan operations.
///
/// `maybe_add_send!` presumably adds a `Send` bound only on targets that
/// support it — confirm against the macro definition.
pub type PrefixStream<'a> = Pin<Box<maybe_add_send!(dyn Stream<Item = (Vec<u8>, Vec<u8>)> + 'a)>>;
/// Zero-sized lifetime witness.
///
/// The type `&'small &'big ()` is only well-formed when `'big: 'small`.
/// `Database::autocommit` passes one to its closure so that the closure's
/// transaction lifetime cannot outlive the database borrow.
pub type PhantomBound<'big, 'small> = PhantomData<&'small &'big ()>;
/// Failure returned by `Database::autocommit`.
#[derive(Debug, Error)]
pub enum AutocommitError<E> {
    /// Committing failed on the final permitted attempt.
    #[error("Commit Failed: {last_error}")]
    CommitFailed {
        /// Number of attempts made, including the last one.
        attempts: usize,
        /// Error returned by the final commit attempt.
        last_error: anyhow::Error,
    },
    /// The user closure returned an error; such errors are not retried.
    #[error("Closure error: {error}")]
    ClosureError {
        /// Attempt number on which the closure failed.
        attempts: usize,
        /// Error returned by the closure.
        error: E,
    },
}
/// Backend (storage engine) database interface.
///
/// `Database::new` wraps implementors in a notification-aware base layer.
#[apply(async_trait_maybe_send!)]
pub trait IRawDatabase: Debug + MaybeSend + MaybeSync + 'static {
    /// Backend transaction type produced by `begin_transaction`.
    type Transaction<'a>: IRawDatabaseTransaction + Debug;

    /// Starts a new backend transaction.
    async fn begin_transaction<'a>(&'a self) -> Self::Transaction<'a>;

    /// Checkpoints the database to `backup_path`.
    // NOTE(review): exact semantics (copy, snapshot, overwrite behaviour) are
    // backend-defined and not visible here — confirm per implementation.
    fn checkpoint(&self, backup_path: &Path) -> Result<()>;
}
/// Lets a boxed raw database be used wherever an [`IRawDatabase`] is expected
/// by forwarding every call to the boxed value.
#[apply(async_trait_maybe_send!)]
impl<T> IRawDatabase for Box<T>
where
    T: IRawDatabase,
{
    type Transaction<'a> = <T as IRawDatabase>::Transaction<'a>;

    async fn begin_transaction<'a>(&'a self) -> Self::Transaction<'a> {
        <T as IRawDatabase>::begin_transaction(&**self).await
    }

    fn checkpoint(&self, backup_path: &Path) -> Result<()> {
        <T as IRawDatabase>::checkpoint(&**self, backup_path)
    }
}
/// Convenience conversion from any raw database into a [`Database`] handle.
pub trait IRawDatabaseExt: IRawDatabase + Sized {
    /// Wraps `self` into a [`Database`] with a default (empty) module decoder
    /// registry, via the `From<T: IRawDatabase>` conversion.
    fn into_database(self) -> Database {
        Database::from(self)
    }
}
/// Every raw database gets `into_database` for free.
impl<T> IRawDatabaseExt for T where T: IRawDatabase {}
impl<T> From<T> for Database
where
T: IRawDatabase,
{
fn from(raw: T) -> Self {
Self::new(raw, ModuleRegistry::default())
}
}
/// Type-erased database interface used behind [`Database`].
///
/// Implemented by the base layer (which owns notifications) and by prefixing
/// wrappers that scope all keys under a byte prefix.
#[apply(async_trait_maybe_send!)]
pub trait IDatabase: Debug + MaybeSend + MaybeSync + 'static {
    /// Starts a new transaction.
    async fn begin_transaction<'a>(&'a self) -> Box<dyn IDatabaseTransaction + 'a>;
    /// Waits for a notification about `key`.
    // NOTE(review): this is an `async fn` behind `async_trait`, so nothing in
    // its body runs until the returned future is first polled.
    async fn register(&self, key: &[u8]);
    /// Signals waiters registered on `key`.
    async fn notify(&self, key: &[u8]);
    /// Whether this database addresses the whole, un-prefixed key space.
    fn is_global(&self) -> bool;
    /// Checkpoints the underlying database to `backup_path`.
    fn checkpoint(&self, backup_path: &Path) -> Result<()>;
}
/// Forwards every [`IDatabase`] call through the `Arc` to the shared value,
/// so `Arc<dyn IDatabase>` is itself an `IDatabase`.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabase for Arc<T>
where
    T: IDatabase + ?Sized,
{
    async fn begin_transaction<'a>(&'a self) -> Box<dyn IDatabaseTransaction + 'a> {
        <T as IDatabase>::begin_transaction(&**self).await
    }

    async fn register(&self, key: &[u8]) {
        <T as IDatabase>::register(&**self, key).await;
    }

    async fn notify(&self, key: &[u8]) {
        <T as IDatabase>::notify(&**self, key).await;
    }

    fn is_global(&self) -> bool {
        <T as IDatabase>::is_global(&**self)
    }

    fn checkpoint(&self, backup_path: &Path) -> Result<()> {
        <T as IDatabase>::checkpoint(&**self, backup_path)
    }
}
/// Bottom layer of every [`Database`]: a raw backend plus the notification
/// registry that committed writes are reported to.
struct BaseDatabase<RawDatabase> {
    /// Shared with every transaction started from this database.
    notifications: Arc<Notifications>,
    /// Backend storage.
    raw: RawDatabase,
}
/// Prints only the type name. `RawDatabase` is not required to be `Debug`,
/// so neither the backend nor the notification state is shown.
impl<RawDatabase> fmt::Debug for BaseDatabase<RawDatabase> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "BaseDatabase")
    }
}
#[apply(async_trait_maybe_send!)]
impl<RawDatabase: IRawDatabase + MaybeSend + 'static> IDatabase for BaseDatabase<RawDatabase> {
async fn begin_transaction<'a>(&'a self) -> Box<dyn IDatabaseTransaction + 'a> {
Box::new(BaseDatabaseTransaction::new(
self.raw.begin_transaction().await,
self.notifications.clone(),
))
}
async fn register(&self, key: &[u8]) {
self.notifications.register(key).await;
}
async fn notify(&self, key: &[u8]) {
self.notifications.notify(key);
}
fn is_global(&self) -> bool {
true
}
fn checkpoint(&self, backup_path: &Path) -> Result<()> {
self.raw.checkpoint(backup_path)
}
}
/// Cheaply clonable handle to a database.
///
/// Clones share the same inner [`IDatabase`] through an `Arc`; each handle
/// carries its own module decoder registry used to decode typed values.
#[derive(Clone, Debug)]
pub struct Database {
    inner: Arc<dyn IDatabase + 'static>,
    module_decoders: ModuleDecoderRegistry,
}
impl Database {
    /// Returns `Arc::strong_count` of the shared inner database.
    pub fn strong_count(&self) -> usize {
        let shared = &self.inner;
        Arc::strong_count(shared)
    }

    /// Consumes the handle and returns the shared inner database,
    /// discarding the module decoders.
    pub fn into_inner(self) -> Arc<dyn IDatabase + 'static> {
        let Self { inner, .. } = self;
        inner
    }
}
/// Construction, scoping and transaction entry points for [`Database`].
impl Database {
    /// Creates a database handle backed by `raw`.
    ///
    /// `raw` is wrapped together with a fresh `Notifications` registry, so
    /// handles created by separate `new` calls do not share notifications.
    pub fn new(raw: impl IRawDatabase + 'static, module_decoders: ModuleDecoderRegistry) -> Self {
        let inner = BaseDatabase {
            raw,
            notifications: Arc::new(Notifications::new()),
        };

        Self::new_from_arc(
            Arc::new(inner) as Arc<dyn IDatabase + 'static>,
            module_decoders,
        )
    }

    /// Creates a handle around an already type-erased [`IDatabase`].
    pub fn new_from_arc(
        inner: Arc<dyn IDatabase + 'static>,
        module_decoders: ModuleDecoderRegistry,
    ) -> Self {
        Self {
            inner,
            module_decoders,
        }
    }

    /// Returns a handle on the same underlying database in which every key is
    /// transparently prefixed with `prefix`.
    ///
    /// No global-access token is attached, so `is_global()` of the result is
    /// whatever the wrapped database reports.
    pub fn with_prefix(&self, prefix: Vec<u8>) -> Self {
        Self {
            inner: Arc::new(PrefixDatabase {
                inner: self.inner.clone(),
                global_dbtx_access_token: None,
                prefix,
            }),
            module_decoders: self.module_decoders.clone(),
        }
    }

    /// Returns a handle scoped to the key space of module
    /// `module_instance_id`, plus the access token derived from its prefix.
    ///
    /// The prefix is `MODULE_GLOBAL_PREFIX` followed by the encoded id. The
    /// returned handle reports `is_global() == false`. The token is what its
    /// transactions compare against in `global_dbtx`.
    pub fn with_prefix_module_id(
        &self,
        module_instance_id: ModuleInstanceId,
    ) -> (Self, GlobalDBTxAccessToken) {
        let prefix = module_instance_id_to_byte_prefix(module_instance_id);
        let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix);
        (
            Self {
                inner: Arc::new(PrefixDatabase {
                    inner: self.inner.clone(),
                    global_dbtx_access_token: Some(global_dbtx_access_token),
                    prefix,
                }),
                module_decoders: self.module_decoders.clone(),
            },
            global_dbtx_access_token,
        )
    }

    /// Returns a handle on the same inner database using `module_decoders`.
    pub fn with_decoders(&self, module_decoders: ModuleDecoderRegistry) -> Self {
        Self {
            inner: self.inner.clone(),
            module_decoders,
        }
    }

    /// Whether this handle addresses the whole (un-prefixed) database, as
    /// reported by the inner [`IDatabase`].
    pub fn is_global(&self) -> bool {
        self.inner.is_global()
    }

    /// Errors unless [`Self::is_global`] is true.
    pub fn ensure_global(&self) -> Result<()> {
        if !self.is_global() {
            bail!("Database instance not global");
        }

        Ok(())
    }

    /// Errors if [`Self::is_global`] is true.
    pub fn ensure_isolated(&self) -> Result<()> {
        if self.is_global() {
            bail!("Database instance not isolated");
        }

        Ok(())
    }

    /// Starts a new committable transaction carrying this handle's decoders.
    pub async fn begin_transaction<'s, 'tx>(&'s self) -> DatabaseTransaction<'tx, Committable>
    where
        's: 'tx,
    {
        DatabaseTransaction::<Committable>::new(
            self.inner.begin_transaction().await,
            self.module_decoders.clone(),
        )
    }

    /// Starts a new transaction and immediately downgrades it to
    /// [`NonCommittable`].
    pub async fn begin_transaction_nc<'s, 'tx>(&'s self) -> DatabaseTransaction<'tx, NonCommittable>
    where
        's: 'tx,
    {
        self.begin_transaction().await.into_nc()
    }

    /// Delegates to the inner database's `checkpoint` with `backup_path`.
    pub fn checkpoint(&self, backup_path: &Path) -> Result<()> {
        self.inner.checkpoint(backup_path)
    }

    /// Runs `tx_fn` in a fresh transaction and commits it, retrying when the
    /// commit fails.
    ///
    /// - If `tx_fn` returns `Err`, `ignore_uncommitted()` is called on the
    ///   transaction and `AutocommitError::ClosureError` is returned
    ///   immediately. Closure errors are never retried.
    /// - If the commit fails and `max_attempts` is `Some(n)` with `n` attempts
    ///   used, `AutocommitError::CommitFailed` carries the last error.
    ///   With `None`, it retries indefinitely.
    /// - Before each retry it sleeps a random delay in `[base, 2 * base)` ms,
    ///   where `base = min(10 * 2^min(attempt, 7), 1000)`.
    ///
    /// Since `tx_fn` may run several times, effects outside the transaction
    /// may be repeated.
    ///
    /// # Panics
    /// Panics if `max_attempts == Some(0)` or if the attempt counter overflows.
    pub async fn autocommit<'s, 'dbtx, F, T, E>(
        &'s self,
        tx_fn: F,
        max_attempts: Option<usize>,
    ) -> Result<T, AutocommitError<E>>
    where
        's: 'dbtx,
        for<'r, 'o> F: Fn(
            &'r mut DatabaseTransaction<'o>,
            PhantomBound<'dbtx, 'o>,
        ) -> BoxFuture<'r, Result<T, E>>,
    {
        assert_ne!(max_attempts, Some(0));
        let mut curr_attempts: usize = 0;

        loop {
            curr_attempts = curr_attempts
                .checked_add(1)
                .expect("db autocommit attempt counter overflowed");

            let mut dbtx = self.begin_transaction().await;

            // The closure only gets a non-committable view; committing is
            // done here so it can be retried.
            let tx_fn_res = tx_fn(&mut dbtx.to_ref_nc(), PhantomData).await;
            let val = match tx_fn_res {
                Ok(val) => val,
                Err(err) => {
                    dbtx.ignore_uncommitted();
                    return Err(AutocommitError::ClosureError {
                        attempts: curr_attempts,
                        error: err,
                    });
                }
            };

            let _timing = timing::TimeReporter::new("autocommit - commit_tx");

            match dbtx.commit_tx_result().await {
                Ok(()) => {
                    return Ok(val);
                }
                Err(err) => {
                    if max_attempts.is_some_and(|max_att| max_att <= curr_attempts) {
                        warn!(
                            target: LOG_DB,
                            curr_attempts,
                            ?err,
                            "Database commit failed in an autocommit block - terminating"
                        );
                        return Err(AutocommitError::CommitFailed {
                            attempts: curr_attempts,
                            last_error: err,
                        });
                    }

                    // Exponential backoff capped at 1s, jittered into
                    // [delay, 2 * delay) to de-synchronize competing writers.
                    let delay = (2u64.pow(curr_attempts.min(7) as u32) * 10).min(1000);
                    let delay = rand::thread_rng().gen_range(delay..(2 * delay));
                    warn!(
                        target: LOG_DB,
                        curr_attempts,
                        %err,
                        delay_ms = %delay,
                        "Database commit failed in an autocommit block - retrying"
                    );
                    crate::runtime::sleep(Duration::from_millis(delay)).await;
                }
            }
        }
    }

    /// Waits until `checker` accepts the current value of `key`.
    ///
    /// Returns the checker's output together with the (still open)
    /// transaction the accepted value was read from.
    pub async fn wait_key_check<'a, K, T>(
        &'a self,
        key: &K,
        checker: impl Fn(Option<K::Value>) -> Option<T>,
    ) -> (T, DatabaseTransaction<'a, Committable>)
    where
        K: DatabaseKey + DatabaseRecord + DatabaseKeyWithNotify,
    {
        let key_bytes = key.to_bytes();
        loop {
            // NOTE(review): `register` is an `async_trait` method, so this
            // only creates a lazy future. Its body — and presumably the actual
            // subscription — runs at `notify.await` below, i.e. after the read.
            // A commit to `key` landing between the read and that await may
            // then be missed until the next modification. Confirm against
            // `Notifications::register`.
            let notify = self.inner.register(&key_bytes);

            let mut tx = self.inner.begin_transaction().await;

            let maybe_value_bytes = tx
                .raw_get_bytes(&key_bytes)
                .await
                .expect("Unrecoverable error when reading from database")
                .map(|value_bytes| {
                    decode_value_expect(&value_bytes, &self.module_decoders, &key_bytes)
                });

            if let Some(value) = checker(maybe_value_bytes) {
                return (
                    value,
                    DatabaseTransaction::new(tx, self.module_decoders.clone()),
                );
            }

            // Not accepted yet: wait for a notification, then re-read.
            notify.await;
        }
    }

    /// Waits until `key` has a value and returns it; the read transaction
    /// is dropped.
    pub async fn wait_key_exists<K>(&self, key: &K) -> K::Value
    where
        K: DatabaseKey + DatabaseRecord + DatabaseKeyWithNotify,
    {
        self.wait_key_check(key, std::convert::identity).await.0
    }
}
/// Builds the key prefix of a module instance's private key space:
/// `MODULE_GLOBAL_PREFIX` followed by the consensus encoding of the id.
fn module_instance_id_to_byte_prefix(module_instance_id: u16) -> Vec<u8> {
    std::iter::once(MODULE_GLOBAL_PREFIX)
        .chain(module_instance_id.consensus_encode_to_vec())
        .collect()
}
/// [`IDatabase`] view that prepends `prefix` to every key it touches.
#[derive(Clone, Debug)]
struct PrefixDatabase<Inner>
where
    Inner: Debug,
{
    /// Bytes prepended to every key.
    prefix: Vec<u8>,
    /// `Some` for module-scoped views. It makes the view non-global, and
    /// transactions of the view check it in `global_dbtx`.
    global_dbtx_access_token: Option<GlobalDBTxAccessToken>,
    /// Wrapped database.
    inner: Inner,
}
impl<Inner> PrefixDatabase<Inner>
where
    Inner: Debug,
{
    /// Returns `key` with this view's prefix prepended.
    fn get_full_key(&self, key: &[u8]) -> Vec<u8> {
        [self.prefix.as_slice(), key].concat()
    }
}
/// Prefixing [`IDatabase`]: transactions and notification keys are scoped
/// under `prefix`; checkpointing is delegated unchanged.
#[apply(async_trait_maybe_send!)]
impl<Inner> IDatabase for PrefixDatabase<Inner>
where
    Inner: Debug + MaybeSend + MaybeSync + 'static + IDatabase,
{
    async fn begin_transaction<'a>(&'a self) -> Box<dyn IDatabaseTransaction + 'a> {
        let inner = self.inner.begin_transaction().await;
        Box::new(PrefixDatabaseTransaction {
            inner,
            global_dbtx_access_token: self.global_dbtx_access_token,
            prefix: self.prefix.clone(),
        })
    }

    async fn register(&self, key: &[u8]) {
        let full_key = self.get_full_key(key);
        self.inner.register(&full_key).await;
    }

    async fn notify(&self, key: &[u8]) {
        let full_key = self.get_full_key(key);
        self.inner.notify(&full_key).await;
    }

    /// A module-scoped view (one holding an access token) is never global;
    /// otherwise the answer comes from the wrapped database.
    fn is_global(&self) -> bool {
        self.global_dbtx_access_token.is_none() && self.inner.is_global()
    }

    fn checkpoint(&self, backup_path: &Path) -> Result<()> {
        self.inner.checkpoint(backup_path)
    }
}
/// Transaction view that prepends `prefix` to every key and strips it again
/// from keys returned by scans.
#[derive(Debug)]
struct PrefixDatabaseTransaction<Inner> {
    /// Wrapped transaction.
    inner: Inner,
    /// `Some` for module-scoped views. It makes the view non-global and must
    /// match the token passed to `global_dbtx`.
    global_dbtx_access_token: Option<GlobalDBTxAccessToken>,
    /// Bytes prepended to every key.
    prefix: Vec<u8>,
}
impl<Inner> PrefixDatabaseTransaction<Inner> {
    /// Returns `key` with this view's prefix prepended.
    fn get_full_key(&self, key: &[u8]) -> Vec<u8> {
        let mut prefixed = Vec::with_capacity(self.prefix.len() + key.len());
        prefixed.extend_from_slice(&self.prefix);
        prefixed.extend_from_slice(key);
        prefixed
    }

    /// Prefixes both ends of `range`.
    fn get_full_range(&self, range: Range<&[u8]>) -> Range<Vec<u8>> {
        self.get_full_key(range.start)..self.get_full_key(range.end)
    }

    /// Strips the first `prefix_len` bytes from every key yielded by
    /// `stream`; values pass through untouched.
    ///
    /// Panics if a yielded key is shorter than `prefix_len`. That cannot
    /// happen for streams produced from prefixed queries.
    fn adapt_prefix_stream(stream: PrefixStream<'_>, prefix_len: usize) -> PrefixStream<'_> {
        Box::pin(stream.map(move |(full_key, value)| (full_key[prefix_len..].to_vec(), value)))
    }
}
#[apply(async_trait_maybe_send!)]
impl<Inner> IDatabaseTransaction for PrefixDatabaseTransaction<Inner>
where
    Inner: IDatabaseTransaction,
{
    /// Committing a prefixed view commits the whole wrapped transaction.
    async fn commit_tx(&mut self) -> Result<()> {
        self.inner.commit_tx().await
    }

    /// A module-scoped view (one holding an access token) is never global;
    /// otherwise the answer comes from the wrapped transaction.
    fn is_global(&self) -> bool {
        self.global_dbtx_access_token.is_none() && self.inner.is_global()
    }

    /// Returns the un-prefixed transaction underneath this view.
    ///
    /// A module-scoped view asserts that `access_token` matches its own token
    /// and hands out its direct inner transaction. A plain prefixed view
    /// forwards the request to the wrapped transaction.
    fn global_dbtx(
        &mut self,
        access_token: GlobalDBTxAccessToken,
    ) -> &mut dyn IDatabaseTransaction {
        match self.global_dbtx_access_token {
            None => self.inner.global_dbtx(access_token),
            Some(expected_token) => {
                assert_eq!(
                    access_token, expected_token,
                    "Invalid access key used to access global_dbtx"
                );
                &mut self.inner
            }
        }
    }
}
/// Byte-level operations on a prefixed view.
///
/// Keys are prefixed on the way in; keys yielded by scans have the prefix
/// stripped on the way out.
#[apply(async_trait_maybe_send!)]
impl<Inner> IDatabaseTransactionOpsCore for PrefixDatabaseTransaction<Inner>
where
    Inner: IDatabaseTransactionOpsCore,
{
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
        let full_key = self.get_full_key(key);
        self.inner.raw_insert_bytes(&full_key, value).await
    }

    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        let full_key = self.get_full_key(key);
        self.inner.raw_get_bytes(&full_key).await
    }

    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        let full_key = self.get_full_key(key);
        self.inner.raw_remove_entry(&full_key).await
    }

    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
        let prefix_len = self.prefix.len();
        let full_prefix = self.get_full_key(key_prefix);
        self.inner
            .raw_find_by_prefix(&full_prefix)
            .await
            .map(|stream| Self::adapt_prefix_stream(stream, prefix_len))
    }

    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>> {
        let prefix_len = self.prefix.len();
        let full_prefix = self.get_full_key(key_prefix);
        self.inner
            .raw_find_by_prefix_sorted_descending(&full_prefix)
            .await
            .map(|stream| Self::adapt_prefix_stream(stream, prefix_len))
    }

    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
        let prefix_len = self.prefix.len();
        let full_range = self.get_full_range(range);
        self.inner
            .raw_find_by_range(full_range.start.as_slice()..full_range.end.as_slice())
            .await
            .map(|stream| Self::adapt_prefix_stream(stream, prefix_len))
    }

    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
        let full_prefix = self.get_full_key(key_prefix);
        self.inner.raw_remove_by_prefix(&full_prefix).await
    }
}
/// Savepoint operations on a prefixed view.
///
/// Savepoints belong to the whole underlying transaction, not to the key
/// prefix, so both operations forward to `inner` unchanged.
#[apply(async_trait_maybe_send!)]
impl<Inner> IDatabaseTransactionOps for PrefixDatabaseTransaction<Inner>
where
    Inner: IDatabaseTransactionOps,
{
    /// Rolls the wrapped transaction back to its last savepoint.
    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
        self.inner.rollback_tx_to_savepoint().await
    }

    /// Sets a savepoint on the wrapped transaction.
    async fn set_tx_savepoint(&mut self) -> Result<()> {
        // Must forward to `inner`: calling `self.set_tx_savepoint()` here
        // would invoke this same method again and recurse without end.
        self.inner.set_tx_savepoint().await
    }
}
/// Byte-level key/value operations available on every transaction.
#[apply(async_trait_maybe_send!)]
pub trait IDatabaseTransactionOpsCore: MaybeSend {
    /// Stores `value` under `key`, returning the previous value if one existed
    /// (`insert_new_entry` treats `Some` as an overwrite).
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>>;

    /// Reads the value stored under `key`, if any.
    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>>;

    /// Removes `key`; presumably returns the removed value, if any.
    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>>;

    /// Streams all entries whose key starts with `key_prefix`.
    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>>;

    /// Like `raw_find_by_prefix`, but in descending key order.
    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>>;

    /// Streams entries whose key falls in `range`.
    // NOTE(review): presumably half-open `[start, end)` like `Range`;
    // confirm per backend.
    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>>;

    /// Removes all entries whose key starts with `key_prefix`.
    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()>;
}
/// Forwards every byte-level operation to the boxed transaction.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabaseTransactionOpsCore for Box<T>
where
    T: IDatabaseTransactionOpsCore + ?Sized,
{
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_insert_bytes(&mut **self, key, value).await
    }

    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_get_bytes(&mut **self, key).await
    }

    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_remove_entry(&mut **self, key).await
    }

    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_prefix(&mut **self, key_prefix).await
    }

    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_prefix_sorted_descending(
            &mut **self,
            key_prefix,
        )
        .await
    }

    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_range(&mut **self, range).await
    }

    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
        <T as IDatabaseTransactionOpsCore>::raw_remove_by_prefix(&mut **self, key_prefix).await
    }
}
/// Forwards every byte-level operation through the mutable reference.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabaseTransactionOpsCore for &mut T
where
    T: IDatabaseTransactionOpsCore + ?Sized,
{
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_insert_bytes(&mut **self, key, value).await
    }

    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_get_bytes(&mut **self, key).await
    }

    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        <T as IDatabaseTransactionOpsCore>::raw_remove_entry(&mut **self, key).await
    }

    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_prefix(&mut **self, key_prefix).await
    }

    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_prefix_sorted_descending(
            &mut **self,
            key_prefix,
        )
        .await
    }

    async fn raw_find_by_range(&mut self, range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
        <T as IDatabaseTransactionOpsCore>::raw_find_by_range(&mut **self, range).await
    }

    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
        <T as IDatabaseTransactionOpsCore>::raw_remove_by_prefix(&mut **self, key_prefix).await
    }
}
/// Byte-level operations plus savepoint control.
#[apply(async_trait_maybe_send!)]
pub trait IDatabaseTransactionOps: IDatabaseTransactionOpsCore + MaybeSend {
    /// Records a savepoint in the transaction.
    async fn set_tx_savepoint(&mut self) -> Result<()>;

    /// Rolls the transaction back to the last savepoint.
    async fn rollback_tx_to_savepoint(&mut self) -> Result<()>;
}
/// Forwards savepoint operations to the boxed transaction.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabaseTransactionOps for Box<T>
where
    T: IDatabaseTransactionOps + ?Sized,
{
    async fn set_tx_savepoint(&mut self) -> Result<()> {
        <T as IDatabaseTransactionOps>::set_tx_savepoint(&mut **self).await
    }

    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
        <T as IDatabaseTransactionOps>::rollback_tx_to_savepoint(&mut **self).await
    }
}
/// Forwards savepoint operations through the mutable reference.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabaseTransactionOps for &mut T
where
    T: IDatabaseTransactionOps + ?Sized,
{
    async fn set_tx_savepoint(&mut self) -> Result<()> {
        <T as IDatabaseTransactionOps>::set_tx_savepoint(&mut **self).await
    }

    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
        <T as IDatabaseTransactionOps>::rollback_tx_to_savepoint(&mut **self).await
    }
}
/// Typed key/value operations layered over the raw byte API.
///
/// Keys are encoded with `to_bytes` and values decoded with the transaction's
/// module decoders. The blanket implementation in this module panics on
/// backend errors and decoding failures. The `'a` parameter is not used by
/// any of the methods declared here.
#[apply(async_trait_maybe_send!)]
pub trait IDatabaseTransactionOpsCoreTyped<'a> {
    /// Reads and decodes the value stored under `key`, if any.
    async fn get_value<K>(&mut self, key: &K) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync;

    /// Stores `value` under `key`, returning the decoded previous value.
    async fn insert_entry<K>(&mut self, key: &K, value: &K::Value) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync;

    /// Stores `value` under `key`; the blanket impl panics if `key` already
    /// had a value.
    async fn insert_new_entry<K>(&mut self, key: &K, value: &K::Value)
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync;

    /// Streams decoded entries whose key lies in `key_range`.
    async fn find_by_range<K>(
        &mut self,
        key_range: Range<K>,
    ) -> Pin<Box<maybe_add_send!(dyn Stream<Item = (K, K::Value)> + '_)>>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync;

    /// Streams decoded entries whose key starts with `key_prefix`.
    async fn find_by_prefix<KP>(
        &mut self,
        key_prefix: &KP,
    ) -> Pin<
        Box<
            maybe_add_send!(
                dyn Stream<
                        Item = (
                            KP::Record,
                            <<KP as DatabaseLookup>::Record as DatabaseRecord>::Value,
                        ),
                    > + '_
            ),
        >,
    >
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync,
        KP::Record: DatabaseKey;

    /// Like `find_by_prefix`, but in descending key order.
    async fn find_by_prefix_sorted_descending<KP>(
        &mut self,
        key_prefix: &KP,
    ) -> Pin<
        Box<
            maybe_add_send!(
                dyn Stream<
                        Item = (
                            KP::Record,
                            <<KP as DatabaseLookup>::Record as DatabaseRecord>::Value,
                        ),
                    > + '_
            ),
        >,
    >
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync,
        KP::Record: DatabaseKey;

    /// Removes `key`, returning the decoded removed value, if any.
    async fn remove_entry<K>(&mut self, key: &K) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync;

    /// Removes all entries whose key starts with `key_prefix`.
    async fn remove_by_prefix<KP>(&mut self, key_prefix: &KP)
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync;
}
/// Typed operations for every transaction that has raw byte operations and
/// module decoders.
///
/// Backend errors and decoding failures are treated as unrecoverable and
/// panic with a message naming the operation that failed.
#[apply(async_trait_maybe_send!)]
impl<'a, T> IDatabaseTransactionOpsCoreTyped<'a> for T
where
    T: IDatabaseTransactionOpsCore + WithDecoders,
{
    /// Reads and decodes the value stored under `key`, if any.
    async fn get_value<K>(&mut self, key: &K) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
    {
        let key_bytes = key.to_bytes();
        let raw = self
            .raw_get_bytes(&key_bytes)
            .await
            .expect("Unrecoverable error occurred while reading an entry from the database");
        raw.map(|value_bytes| {
            decode_value_expect::<K::Value>(&value_bytes, self.decoders(), &key_bytes)
        })
    }

    /// Stores `value` under `key`, returning the decoded previous value.
    async fn insert_entry<K>(&mut self, key: &K, value: &K::Value) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync,
    {
        let key_bytes = key.to_bytes();
        self.raw_insert_bytes(&key_bytes, &value.to_bytes())
            .await
            .expect("Unrecoverable error occurred while inserting entry into the database")
            .map(|value_bytes| {
                decode_value_expect::<K::Value>(&value_bytes, self.decoders(), &key_bytes)
            })
    }

    /// Stores `value` under `key`, panicking if it overwrote an existing value.
    async fn insert_new_entry<K>(&mut self, key: &K, value: &K::Value)
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync,
    {
        if let Some(prev) = self.insert_entry(key, value).await {
            panic!(
                "Database overwriting element when expecting insertion of new entry. Key: {key:?} Prev Value: {prev:?}"
            );
        }
    }

    /// Streams decoded entries whose key lies in `key_range`.
    async fn find_by_range<K>(
        &mut self,
        key_range: Range<K>,
    ) -> Pin<Box<maybe_add_send!(dyn Stream<Item = (K, K::Value)> + '_)>>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
        K::Value: MaybeSend + MaybeSync,
    {
        // Cloned so the decoding closure owns its registry and does not
        // borrow `self` alongside the stream.
        let decoders = self.decoders().clone();
        Box::pin(
            self.raw_find_by_range(Range {
                start: &key_range.start.to_bytes(),
                end: &key_range.end.to_bytes(),
            })
            .await
            .expect("Unrecoverable error occurred while listing entries from the database")
            .map(move |(key_bytes, value_bytes)| {
                let key = decode_key_expect(&key_bytes, &decoders);
                let value = decode_value_expect(&value_bytes, &decoders, &key_bytes);
                (key, value)
            }),
        )
    }

    /// Streams decoded entries whose key starts with `key_prefix`.
    async fn find_by_prefix<KP>(
        &mut self,
        key_prefix: &KP,
    ) -> Pin<
        Box<
            maybe_add_send!(
                dyn Stream<
                        Item = (
                            KP::Record,
                            <<KP as DatabaseLookup>::Record as DatabaseRecord>::Value,
                        ),
                    > + '_
            ),
        >,
    >
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync,
        KP::Record: DatabaseKey,
    {
        let decoders = self.decoders().clone();
        Box::pin(
            self.raw_find_by_prefix(&key_prefix.to_bytes())
                .await
                .expect("Unrecoverable error occurred while listing entries from the database")
                .map(move |(key_bytes, value_bytes)| {
                    let key = decode_key_expect(&key_bytes, &decoders);
                    let value = decode_value_expect(&value_bytes, &decoders, &key_bytes);
                    (key, value)
                }),
        )
    }

    /// Like `find_by_prefix`, but in descending key order.
    async fn find_by_prefix_sorted_descending<KP>(
        &mut self,
        key_prefix: &KP,
    ) -> Pin<
        Box<
            maybe_add_send!(
                dyn Stream<
                        Item = (
                            KP::Record,
                            <<KP as DatabaseLookup>::Record as DatabaseRecord>::Value,
                        ),
                    > + '_
            ),
        >,
    >
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync,
        KP::Record: DatabaseKey,
    {
        let decoders = self.decoders().clone();
        Box::pin(
            self.raw_find_by_prefix_sorted_descending(&key_prefix.to_bytes())
                .await
                .expect("Unrecoverable error occurred while listing entries from the database")
                .map(move |(key_bytes, value_bytes)| {
                    let key = decode_key_expect(&key_bytes, &decoders);
                    let value = decode_value_expect(&value_bytes, &decoders, &key_bytes);
                    (key, value)
                }),
        )
    }

    /// Removes `key`, returning the decoded removed value, if any.
    async fn remove_entry<K>(&mut self, key: &K) -> Option<K::Value>
    where
        K: DatabaseKey + DatabaseRecord + MaybeSend + MaybeSync,
    {
        let key_bytes = key.to_bytes();
        self.raw_remove_entry(&key_bytes)
            .await
            .expect("Unrecoverable error occurred while removing an entry from the database")
            .map(|value_bytes| {
                decode_value_expect::<K::Value>(&value_bytes, self.decoders(), &key_bytes)
            })
    }

    /// Removes all entries whose key starts with `key_prefix`.
    async fn remove_by_prefix<KP>(&mut self, key_prefix: &KP)
    where
        KP: DatabaseLookup + MaybeSend + MaybeSync,
    {
        self.raw_remove_by_prefix(&key_prefix.to_bytes())
            .await
            .expect("Unrecoverable error when removing entries from the database");
    }
}
/// Access to the module decoder registry used to decode typed keys/values.
pub trait WithDecoders {
    fn decoders(&self) -> &ModuleDecoderRegistry;
}
/// Backend transaction produced by an [`IRawDatabase`].
#[apply(async_trait_maybe_send!)]
pub trait IRawDatabaseTransaction: MaybeSend + IDatabaseTransactionOps {
    /// Commits the transaction, consuming it.
    async fn commit_tx(self) -> Result<()>;
}
/// Object-safe transaction interface used behind `DatabaseTransaction`.
#[apply(async_trait_maybe_send!)]
pub trait IDatabaseTransaction: MaybeSend + IDatabaseTransactionOps + fmt::Debug {
    /// Commits the transaction. Takes `&mut self` so it can be called through
    /// a trait object.
    async fn commit_tx(&mut self) -> Result<()>;

    /// Whether this transaction addresses the whole, un-prefixed key space.
    fn is_global(&self) -> bool;

    /// Returns the un-prefixed transaction beneath a module-scoped view.
    ///
    /// Prefixed views assert that `access_token` matches their own token. The
    /// base transaction panics if called.
    #[doc(hidden)]
    fn global_dbtx(&mut self, access_token: GlobalDBTxAccessToken)
        -> &mut dyn IDatabaseTransaction;
}
/// Forwards every transaction operation to the boxed transaction.
#[apply(async_trait_maybe_send!)]
impl<T> IDatabaseTransaction for Box<T>
where
    T: IDatabaseTransaction + ?Sized,
{
    async fn commit_tx(&mut self) -> Result<()> {
        <T as IDatabaseTransaction>::commit_tx(&mut **self).await
    }

    fn is_global(&self) -> bool {
        <T as IDatabaseTransaction>::is_global(&**self)
    }

    fn global_dbtx(
        &mut self,
        access_token: GlobalDBTxAccessToken,
    ) -> &mut dyn IDatabaseTransaction {
        <T as IDatabaseTransaction>::global_dbtx(&mut **self, access_token)
    }
}
/// Forwards every transaction operation through the mutable reference.
#[apply(async_trait_maybe_send!)]
impl<'a, T> IDatabaseTransaction for &'a mut T
where
    T: IDatabaseTransaction + ?Sized,
{
    async fn commit_tx(&mut self) -> Result<()> {
        <T as IDatabaseTransaction>::commit_tx(&mut **self).await
    }

    fn is_global(&self) -> bool {
        <T as IDatabaseTransaction>::is_global(&**self)
    }

    fn global_dbtx(
        &mut self,
        access_token: GlobalDBTxAccessToken,
    ) -> &mut dyn IDatabaseTransaction {
        <T as IDatabaseTransaction>::global_dbtx(&mut **self, access_token)
    }
}
/// Base transaction wrapper: forwards to the backend transaction and queues
/// written keys for notification delivery on commit.
struct BaseDatabaseTransaction<Tx> {
    /// Backend transaction; `None` once `commit_tx` has taken it.
    raw: Option<Tx>,
    /// Keys touched by writes; taken and submitted after a successful commit.
    notify_queue: Option<NotifyQueue>,
    /// Registry the queued keys are submitted to.
    notifications: Arc<Notifications>,
}
/// Shows only the backend transaction; notification state is omitted.
impl<Tx> fmt::Debug for BaseDatabaseTransaction<Tx>
where
    Tx: fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = &self.raw;
        write!(formatter, "BaseDatabaseTransaction{{ raw={raw:?} }}")
    }
}
impl<Tx> BaseDatabaseTransaction<Tx>
where
    Tx: IRawDatabaseTransaction,
{
    /// Wraps a backend transaction, starting with an empty notification queue.
    fn new(dbtx: Tx, notifications: Arc<Notifications>) -> Self {
        let notify_queue = Some(NotifyQueue::new());
        Self {
            notifications,
            notify_queue,
            raw: Some(dbtx),
        }
    }

    /// Queues `key` for submission to the notification registry after a
    /// successful commit.
    ///
    /// Errors if the queue was already taken by `commit_tx`.
    fn add_notification_key(&mut self, key: &[u8]) -> Result<()> {
        let queue = self
            .notify_queue
            .as_mut()
            .context("can not call add_notification_key after commit")?;
        queue.add(&key);
        Ok(())
    }
}
/// Byte-level operations on the base transaction.
///
/// Every operation fails once the backend transaction has been consumed by
/// `commit_tx`. Single-key writes also queue the key for notification.
#[apply(async_trait_maybe_send!)]
impl<Tx: IRawDatabaseTransaction> IDatabaseTransactionOpsCore for BaseDatabaseTransaction<Tx> {
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
        // Queue first: this also fails fast after commit.
        self.add_notification_key(key)?;
        let tx = self
            .raw
            .as_mut()
            .context("Cannot insert into already consumed transaction")?;
        tx.raw_insert_bytes(key, value).await
    }

    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot retrieve from already consumed transaction")?;
        tx.raw_get_bytes(key).await
    }

    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        self.add_notification_key(key)?;
        let tx = self
            .raw
            .as_mut()
            .context("Cannot remove from already consumed transaction")?;
        tx.raw_remove_entry(key).await
    }

    async fn raw_find_by_range(&mut self, key_range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot retrieve from already consumed transaction")?;
        tx.raw_find_by_range(key_range).await
    }

    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot retrieve from already consumed transaction")?;
        tx.raw_find_by_prefix(key_prefix).await
    }

    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot retrieve from already consumed transaction")?;
        tx.raw_find_by_prefix_sorted_descending(key_prefix).await
    }

    // NOTE(review): unlike the single-key writes above, this queues no
    // notification keys, so waiters on removed keys are not woken by it.
    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot remove from already consumed transaction")?;
        tx.raw_remove_by_prefix(key_prefix).await
    }
}
/// Savepoint operations on the base transaction; both fail once the backend
/// transaction has been consumed by `commit_tx`.
#[apply(async_trait_maybe_send!)]
impl<Tx: IRawDatabaseTransaction> IDatabaseTransactionOps for BaseDatabaseTransaction<Tx> {
    async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot rollback to a savepoint on an already consumed transaction")?;
        tx.rollback_tx_to_savepoint().await
    }

    async fn set_tx_savepoint(&mut self) -> Result<()> {
        let tx = self
            .raw
            .as_mut()
            .context("Cannot set a tx savepoint on an already consumed transaction")?;
        tx.set_tx_savepoint().await
    }
}
#[apply(async_trait_maybe_send!)]
impl<Tx: IRawDatabaseTransaction + fmt::Debug> IDatabaseTransaction
    for BaseDatabaseTransaction<Tx>
{
    /// Commits the backend transaction and then submits the queued
    /// notification keys.
    ///
    /// If the backend commit fails, the error is returned and nothing is
    /// submitted. A second call errors because the backend transaction is gone.
    async fn commit_tx(&mut self) -> Result<()> {
        let raw_tx = self
            .raw
            .take()
            .context("Cannot commit an already committed transaction")?;
        raw_tx.commit_tx().await?;

        let queue = self
            .notify_queue
            .take()
            .expect("commit must be called only once");
        self.notifications.submit_queue(&queue);
        Ok(())
    }

    /// The base transaction is never prefixed.
    fn is_global(&self) -> bool {
        true
    }

    /// There is nothing beneath the base transaction, so this always panics.
    fn global_dbtx(
        &mut self,
        _access_token: GlobalDBTxAccessToken,
    ) -> &mut dyn IDatabaseTransaction {
        panic!("Illegal to call global_dbtx on BaseDatabaseTransaction");
    }
}
/// Flags checked on drop to report transactions that were dropped with
/// uncommitted writes.
// NOTE(review): the code that sets these flags is not shown here; meanings
// below are inferred from the field names and the `Drop` impl.
#[derive(Clone)]
struct CommitTracker {
    /// Presumably set once the transaction has been committed.
    is_committed: bool,
    /// Presumably set once any write has been performed.
    has_writes: bool,
    /// When set, dropping uncommitted writes is logged at `trace` instead of
    /// `warn`.
    ignore_uncommitted: bool,
}
/// Logs when a transaction with writes is dropped without committing: at
/// `warn` (with a backtrace) normally, or at `trace` when that was expected.
impl Drop for CommitTracker {
    fn drop(&mut self) {
        // Nothing to report without writes, or if they were committed.
        if !self.has_writes || self.is_committed {
            return;
        }

        if !self.ignore_uncommitted {
            warn!(
                target: LOG_DB,
                location = ?backtrace::Backtrace::new(),
                "DatabaseTransaction has writes and has not called commit."
            );
        } else {
            trace!(
                target: LOG_DB,
                "DatabaseTransaction has writes and has not called commit, but that's expected."
            );
        }
    }
}
/// Either an owned `T` or a mutable borrow of one; dereferences to `T` in
/// both cases.
///
/// Lets a borrowed transaction view share the commit tracker and hooks of
/// the transaction it was derived from.
enum MaybeRef<'a, T> {
    Owned(T),
    Borrowed(&'a mut T),
}
impl<'a, T> ops::Deref for MaybeRef<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeRef::Owned(o) => o,
MaybeRef::Borrowed(r) => r,
}
}
}
impl<'a, T> ops::DerefMut for MaybeRef<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
MaybeRef::Owned(o) => o,
MaybeRef::Borrowed(r) => r,
}
}
}
/// Capability marker for transactions that may be committed; this is what
/// `Database::begin_transaction` returns.
pub struct Committable;
/// Capability marker for transaction views that must not be committed
/// (produced by `into_nc` / `begin_transaction_nc`).
pub struct NonCommittable;
/// User-facing transaction handle.
///
/// `Cap` is a type-level capability ([`Committable`] or [`NonCommittable`],
/// defaulting to the latter). Views created by `to_ref*` borrow the commit
/// tracker and on-commit hooks of the transaction they came from.
pub struct DatabaseTransaction<'tx, Cap = NonCommittable> {
    /// Underlying (possibly prefixed or borrowed) transaction.
    tx: Box<dyn IDatabaseTransaction + 'tx>,
    /// Registry used to decode typed keys and values.
    decoders: ModuleDecoderRegistry,
    /// Owned by the root transaction, borrowed by derived views.
    commit_tracker: MaybeRef<'tx, CommitTracker>,
    /// Callbacks presumably run after a successful commit — confirm in the
    /// commit implementation (not shown here).
    on_commit_hooks: MaybeRef<'tx, Vec<Box<maybe_add_send!(dyn FnOnce())>>>,
    capability: marker::PhantomData<Cap>,
}
/// Shows the underlying transaction and decoders; tracker and hooks are
/// omitted.
impl<'tx, Cap> fmt::Debug for DatabaseTransaction<'tx, Cap> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { tx, decoders, .. } = self;
        write!(f, "DatabaseTransaction {{ tx: {tx:?}, decoders={decoders:?} }}")
    }
}
/// Exposes the transaction's decoders, which enables the blanket
/// `IDatabaseTransactionOpsCoreTyped` implementation.
impl<'tx, Cap> WithDecoders for DatabaseTransaction<'tx, Cap> {
    fn decoders(&self) -> &ModuleDecoderRegistry {
        &self.decoders
    }
}
/// Decodes raw value bytes into `V` using the module decoder registry.
///
/// Runs in a `trace`-level span tagged with the value type name. Decoding
/// errors are recorded on the span (`err`) and returned.
#[instrument(level = "trace", skip_all, fields(value_type = std::any::type_name::<V>()), err)]
fn decode_value<V: DatabaseValue>(
    value_bytes: &[u8],
    decoders: &ModuleDecoderRegistry,
) -> Result<V, DecodingError> {
    trace!(bytes = %AbbreviateHexBytes(value_bytes), "decoding value");
    <V as DatabaseValue>::from_bytes(value_bytes, decoders)
}
/// Decodes a stored value, panicking with the target type, error and
/// abbreviated key/value bytes if decoding fails.
fn decode_value_expect<V: DatabaseValue>(
    value_bytes: &[u8],
    decoders: &ModuleDecoderRegistry,
    key_bytes: &[u8],
) -> V {
    match decode_value(value_bytes, decoders) {
        Ok(value) => value,
        Err(err) => panic!(
            "Unrecoverable decoding DatabaseValue as {}; err={}, key_bytes={}, val_bytes={}",
            any::type_name::<V>(),
            err,
            AbbreviateHexBytes(key_bytes),
            AbbreviateHexBytes(value_bytes),
        ),
    }
}
/// Decodes a stored key, panicking with the target type, error and
/// abbreviated key bytes if decoding fails.
fn decode_key_expect<K: DatabaseKey>(key_bytes: &[u8], decoders: &ModuleDecoderRegistry) -> K {
    trace!(bytes = %AbbreviateHexBytes(key_bytes), "decoding key");
    match <K as DatabaseKey>::from_bytes(key_bytes, decoders) {
        Ok(key) => key,
        Err(err) => panic!(
            "Unrecoverable decoding DatabaseKey as {}; err={}; bytes={}",
            any::type_name::<K>(),
            err,
            AbbreviateHexBytes(key_bytes)
        ),
    }
}
impl<'tx, Cap> DatabaseTransaction<'tx, Cap> {
pub fn into_nc(self) -> DatabaseTransaction<'tx, NonCommittable> {
DatabaseTransaction {
tx: self.tx,
decoders: self.decoders,
commit_tracker: self.commit_tracker,
on_commit_hooks: self.on_commit_hooks,
capability: PhantomData::<NonCommittable>,
}
}
/// Reborrows this transaction as a [`NonCommittable`] view.
///
/// Operations on the view go to this transaction, and the view shares its
/// commit tracker and hooks (see `to_ref`).
pub fn to_ref_nc<'s, 'a>(&'s mut self) -> DatabaseTransaction<'a, NonCommittable>
where
    's: 'a,
{
    self.to_ref().into_nc()
}
/// Consumes this transaction and returns one in which every key is prefixed
/// with `prefix`.
///
/// No access token is attached, so `is_global()` is whatever the wrapped
/// transaction reports. The bounds `'a: 'tx` and `'tx: 'a` together force
/// `'a == 'tx`.
pub fn with_prefix<'a: 'tx>(self, prefix: Vec<u8>) -> DatabaseTransaction<'a, Cap>
where
    'tx: 'a,
{
    DatabaseTransaction {
        tx: Box::new(PrefixDatabaseTransaction {
            inner: self.tx,
            global_dbtx_access_token: None,
            prefix,
        }),
        decoders: self.decoders,
        commit_tracker: self.commit_tracker,
        on_commit_hooks: self.on_commit_hooks,
        capability: self.capability,
    }
}
/// Consumes this transaction and returns one scoped to module
/// `module_instance_id`'s key space, plus the matching access token.
///
/// The returned transaction reports `is_global() == false`. Passing the token
/// to its `global_dbtx` yields the un-prefixed transaction. As in
/// `with_prefix`, the lifetime bounds force `'a == 'tx`.
pub fn with_prefix_module_id<'a: 'tx>(
    self,
    module_instance_id: ModuleInstanceId,
) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken)
where
    'tx: 'a,
{
    let prefix = module_instance_id_to_byte_prefix(module_instance_id);
    let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix);
    (
        DatabaseTransaction {
            tx: Box::new(PrefixDatabaseTransaction {
                inner: self.tx,
                global_dbtx_access_token: Some(global_dbtx_access_token),
                prefix,
            }),
            decoders: self.decoders,
            commit_tracker: self.commit_tracker,
            on_commit_hooks: self.on_commit_hooks,
            capability: self.capability,
        },
        global_dbtx_access_token,
    )
}
pub fn to_ref<'s, 'a>(&'s mut self) -> DatabaseTransaction<'a, Cap>
where
's: 'a,
{
let decoders = self.decoders.clone();
DatabaseTransaction {
tx: Box::new(&mut self.tx),
decoders,
commit_tracker: match self.commit_tracker {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
on_commit_hooks: match self.on_commit_hooks {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
capability: self.capability,
}
}
pub fn to_ref_with_prefix<'a>(&'a mut self, prefix: Vec<u8>) -> DatabaseTransaction<'a, Cap>
where
'tx: 'a,
{
DatabaseTransaction {
tx: Box::new(PrefixDatabaseTransaction {
inner: &mut self.tx,
global_dbtx_access_token: None,
prefix,
}),
decoders: self.decoders.clone(),
commit_tracker: match self.commit_tracker {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
on_commit_hooks: match self.on_commit_hooks {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
capability: self.capability,
}
}
pub fn to_ref_with_prefix_module_id<'a>(
&'a mut self,
module_instance_id: ModuleInstanceId,
) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken)
where
'tx: 'a,
{
let prefix = module_instance_id_to_byte_prefix(module_instance_id);
let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix);
(
DatabaseTransaction {
tx: Box::new(PrefixDatabaseTransaction {
inner: &mut self.tx,
global_dbtx_access_token: Some(global_dbtx_access_token),
prefix,
}),
decoders: self.decoders.clone(),
commit_tracker: match self.commit_tracker {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
on_commit_hooks: match self.on_commit_hooks {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
capability: self.capability,
},
global_dbtx_access_token,
)
}
pub fn is_global(&self) -> bool {
self.tx.is_global()
}
pub fn ensure_global(&self) -> Result<()> {
if !self.is_global() {
bail!("Database instance not global");
}
Ok(())
}
pub fn ensure_isolated(&self) -> Result<()> {
if self.is_global() {
bail!("Database instance not isolated");
}
Ok(())
}
pub fn ignore_uncommitted(&mut self) -> &mut Self {
self.commit_tracker.ignore_uncommitted = true;
self
}
pub fn warn_uncommitted(&mut self) -> &mut Self {
self.commit_tracker.ignore_uncommitted = false;
self
}
#[instrument(level = "trace", skip_all)]
pub fn on_commit(&mut self, f: maybe_add_send!(impl FnOnce() + 'static)) {
self.on_commit_hooks.push(Box::new(f));
}
pub fn global_dbtx<'a>(
&'a mut self,
access_token: GlobalDBTxAccessToken,
) -> DatabaseTransaction<'a, Cap>
where
'tx: 'a,
{
let decoders = self.decoders.clone();
DatabaseTransaction {
tx: Box::new(self.tx.global_dbtx(access_token)),
decoders,
commit_tracker: match self.commit_tracker {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
on_commit_hooks: match self.on_commit_hooks {
MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o),
MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b),
},
capability: self.capability,
}
}
}
/// Token handed out together with a module-prefixed transaction and required
/// by `DatabaseTransaction::global_dbtx` to reach the un-prefixed transaction.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct GlobalDBTxAccessToken(u32);
impl GlobalDBTxAccessToken {
    /// Derives the token from a key prefix: the sum of its bytes plus 513.
    fn from_prefix(prefix: &[u8]) -> Self {
        const OFFSET: u32 = 513;
        let byte_sum: u32 = prefix.iter().copied().map(u32::from).sum();
        Self(byte_sum + OFFSET)
    }
}
impl<'tx> DatabaseTransaction<'tx, Committable> {
    /// Wraps a backend transaction as a fresh, committable transaction that owns
    /// its commit tracker (nothing written, not committed) and an empty hook list.
    pub fn new(dbtx: Box<dyn IDatabaseTransaction + 'tx>, decoders: ModuleDecoderRegistry) -> Self {
        let tracker = CommitTracker {
            is_committed: false,
            has_writes: false,
            ignore_uncommitted: false,
        };
        Self {
            tx: dbtx,
            decoders,
            commit_tracker: MaybeRef::Owned(tracker),
            on_commit_hooks: MaybeRef::Owned(Vec::new()),
            capability: PhantomData,
        }
    }

    /// Commits the backend transaction and returns its result.
    ///
    /// The tracker is marked committed before the attempt. The registered
    /// on-commit hooks run, in registration order, only if the commit succeeded.
    pub async fn commit_tx_result(mut self) -> Result<()> {
        self.commit_tracker.is_committed = true;
        let outcome = self.tx.commit_tx().await;
        if outcome.is_ok() {
            self.on_commit_hooks
                .deref_mut()
                .drain(..)
                .for_each(|hook| hook());
        }
        outcome
    }

    /// Like [`Self::commit_tx_result`] but panics if the commit fails.
    pub async fn commit_tx(mut self) {
        self.commit_tracker.is_committed = true;
        let outcome = self.commit_tx_result().await;
        outcome.expect("Unrecoverable error occurred while committing to the database.");
    }
}
#[apply(async_trait_maybe_send!)]
/// Raw byte-level operations, forwarded to the underlying transaction.
///
/// Every mutating operation (insert, remove, remove-by-prefix) marks the
/// (possibly shared) commit tracker as having writes.
impl<'a, Cap> IDatabaseTransactionOpsCore for DatabaseTransaction<'a, Cap>
where
    Cap: Send,
{
    async fn raw_insert_bytes(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
        self.commit_tracker.has_writes = true;
        self.tx.raw_insert_bytes(key, value).await
    }

    async fn raw_get_bytes(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        self.tx.raw_get_bytes(key).await
    }

    async fn raw_remove_entry(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        // Removing an entry is a write, just like insert and remove-by-prefix;
        // previously it was not recorded in the commit tracker.
        self.commit_tracker.has_writes = true;
        self.tx.raw_remove_entry(key).await
    }

    async fn raw_find_by_range(&mut self, key_range: Range<&[u8]>) -> Result<PrefixStream<'_>> {
        self.tx.raw_find_by_range(key_range).await
    }

    async fn raw_find_by_prefix(&mut self, key_prefix: &[u8]) -> Result<PrefixStream<'_>> {
        self.tx.raw_find_by_prefix(key_prefix).await
    }

    async fn raw_find_by_prefix_sorted_descending(
        &mut self,
        key_prefix: &[u8],
    ) -> Result<PrefixStream<'_>> {
        self.tx
            .raw_find_by_prefix_sorted_descending(key_prefix)
            .await
    }

    async fn raw_remove_by_prefix(&mut self, key_prefix: &[u8]) -> Result<()> {
        self.commit_tracker.has_writes = true;
        self.tx.raw_remove_by_prefix(key_prefix).await
    }
}
#[apply(async_trait_maybe_send!)]
/// Savepoint support, implemented only for [`Committable`] transactions; both
/// methods forward directly to the underlying `tx`.
impl<'a> IDatabaseTransactionOps for DatabaseTransaction<'a, Committable> {
/// Forwards to the underlying transaction's `set_tx_savepoint`.
async fn set_tx_savepoint(&mut self) -> Result<()> {
self.tx.set_tx_savepoint().await
}
/// Forwards to the underlying transaction's `rollback_tx_to_savepoint`.
async fn rollback_tx_to_savepoint(&mut self) -> Result<()> {
self.tx.rollback_tx_to_savepoint().await
}
}
/// Any encodable lookup type serializes as its record's `DB_PREFIX` byte
/// followed by its own consensus encoding.
impl<T> DatabaseKeyPrefix for T
where
    T: DatabaseLookup + crate::encoding::Encodable + Debug,
{
    fn to_bytes(&self) -> Vec<u8> {
        let encoded = self.consensus_encode_to_vec();
        let mut bytes = Vec::with_capacity(encoded.len() + 1);
        bytes.push(<Self as DatabaseLookup>::Record::DB_PREFIX);
        bytes.extend_from_slice(&encoded);
        bytes
    }
}
/// Every decodable record type is its own key: the first byte must be the
/// record's `DB_PREFIX` and the remaining bytes are decoded with
/// `consensus_decode_whole`.
impl<T> DatabaseKey for T
where
    T: DatabaseRecord + crate::encoding::Decodable + Sized,
{
    const NOTIFY_ON_MODIFY: bool = <T as DatabaseRecord>::NOTIFY_ON_MODIFY;

    fn from_bytes(data: &[u8], modules: &ModuleDecoderRegistry) -> Result<Self, DecodingError> {
        let (prefix, encoded_key) = match data.split_first() {
            Some((&prefix, rest)) => (prefix, rest),
            None => return Err(DecodingError::wrong_length(1, 0)),
        };
        if prefix != Self::DB_PREFIX {
            return Err(DecodingError::wrong_prefix(Self::DB_PREFIX, prefix));
        }
        <Self as crate::encoding::Decodable>::consensus_decode_whole(encoded_key, modules)
            .map_err(|decode_error| DecodingError::Other(decode_error.0))
    }
}
/// Any consensus-encodable type can be stored as a database value, using its
/// consensus encoding without any prefix.
impl<T> DatabaseValue for T
where
    T: Debug + Encodable + Decodable,
{
    fn from_bytes(data: &[u8], modules: &ModuleDecoderRegistry) -> Result<Self, DecodingError> {
        match T::consensus_decode_whole(data, modules) {
            Ok(value) => Ok(value),
            Err(decode_error) => Err(DecodingError::Other(decode_error.0)),
        }
    }

    fn to_bytes(&self) -> Vec<u8> {
        self.consensus_encode_to_vec()
    }
}
/// Implements [`DatabaseRecord`](crate::db::DatabaseRecord) for a key type.
///
/// Usage: `impl_db_record!(key = K, value = V, db_prefix = P)` with an optional
/// `notify_on_modify = true|false`. `P` is cast with `as u8`. When
/// `notify_on_modify = true`, `DatabaseKeyWithNotify` is implemented for `K` as
/// well.
#[macro_export]
macro_rules! impl_db_record {
    (key = $key:ty, value = $val:ty, db_prefix = $db_prefix:expr $(, notify_on_modify = $notify:tt)? $(,)?) => {
        impl $crate::db::DatabaseRecord for $key {
            const DB_PREFIX: u8 = $db_prefix as u8;
            $(const NOTIFY_ON_MODIFY: bool = $notify;)?
            type Key = Self;
            type Value = $val;
        }
        $(
            // Path-qualified so the recursion resolves even when the caller
            // invoked this macro by path without importing it.
            $crate::impl_db_record! {
                @impl_notify_marker key = $key, notify_on_modify = $notify
            }
        )?
    };
    // Internal arms: only `true` adds the notification marker trait.
    (@impl_notify_marker key = $key:ty, notify_on_modify = true) => {
        impl $crate::db::DatabaseKeyWithNotify for $key {}
    };
    (@impl_notify_marker key = $key:ty, notify_on_modify = false) => {};
}
/// Implements [`DatabaseLookup`](crate::db::DatabaseLookup) for each
/// `query_prefix` type, tying it to the record type `key`.
///
/// Usage: `impl_db_lookup!(key = K, query_prefix = P1, query_prefix = P2)`.
#[macro_export]
macro_rules! impl_db_lookup{
(key = $key:ty $(, query_prefix = $query_prefix:ty)* $(,)?) => {
$(
impl $crate::db::DatabaseLookup for $query_prefix {
type Record = $key;
}
)*
};
}
/// Legacy, field-less key for a database version record.
///
/// Migration code removes it from a module-prefixed (or global) transaction
/// and converts it into a [`DatabaseVersionKey`] entry.
#[derive(Debug, Encodable, Decodable, Serialize)]
pub struct DatabaseVersionKeyV0;
/// Key of the database version record for one module instance; global data
/// uses `MODULE_GLOBAL_PREFIX` as the instance id.
#[derive(Debug, Encodable, Decodable, Serialize)]
pub struct DatabaseVersionKey(pub ModuleInstanceId);
/// A monotonically increasing database schema version number.
#[derive(Debug, Encodable, Decodable, Serialize, Clone, PartialOrd, Ord, PartialEq, Eq, Copy)]
pub struct DatabaseVersion(pub u64);
// Both version keys share the `DbKeyPrefix::DatabaseVersion` prefix byte;
// they differ in the encoded key bytes that follow it.
impl_db_record!(
key = DatabaseVersionKeyV0,
value = DatabaseVersion,
db_prefix = DbKeyPrefix::DatabaseVersion
);
impl_db_record!(
key = DatabaseVersionKey,
value = DatabaseVersion,
db_prefix = DbKeyPrefix::DatabaseVersion
);
impl std::fmt::Display for DatabaseVersion {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl DatabaseVersion {
    /// Returns the version immediately following this one.
    pub fn increment(&self) -> Self {
        let next = self.0 + 1;
        Self(next)
    }
}
/// Displays the variant name, identical to its `Debug` output.
impl std::fmt::Display for DbKeyPrefix {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}
/// Key prefix bytes used by the records defined in this module.
#[repr(u8)]
#[derive(Clone, EnumIter, Debug)]
pub enum DbKeyPrefix {
/// Prefix of `DatabaseVersionKeyV0` and `DatabaseVersionKey` records.
DatabaseVersion = 0x50,
/// NOTE(review): not used in this section; presumably client backup records.
ClientBackup = 0x51,
}
/// Error returned when database key or value bytes cannot be decoded.
#[derive(Debug, Error)]
pub enum DecodingError {
    /// The first key byte did not match the record's `DB_PREFIX`.
    #[error("Key had a wrong prefix, expected {expected} but got {found}")]
    WrongPrefix { expected: u8, found: u8 },
    /// The input was shorter/longer than required (e.g. an empty key).
    #[error("Key had a wrong length, expected {expected} but got {found}")]
    WrongLength { expected: usize, found: usize },
    /// Any other decoding failure, e.g. from consensus decoding.
    #[error("Other decoding error: {0:?}")]
    Other(anyhow::Error),
}
impl DecodingError {
    /// Wraps an arbitrary error as [`DecodingError::Other`].
    pub fn other<E: Error + Send + Sync + 'static>(error: E) -> Self {
        DecodingError::Other(error.into())
    }

    /// Builds a [`DecodingError::WrongPrefix`].
    pub fn wrong_prefix(expected: u8, found: u8) -> Self {
        DecodingError::WrongPrefix { expected, found }
    }

    /// Builds a [`DecodingError::WrongLength`].
    pub fn wrong_length(expected: usize, found: usize) -> Self {
        DecodingError::WrongLength { expected, found }
    }
}
/// Collects every `(key, value)` pair found under `$prefix_type` in `$dbtx`
/// into a map from the key's consensus hex encoding to its value, and inserts
/// it (boxed) into `$map` under `$key_literal`.
///
/// Must be expanded inside an `async` context. The returned stream is consumed
/// with `map`/`collect`, so `futures::StreamExt` presumably needs to be in scope
/// at the call site. `$key_type` is not used by the expansion.
#[macro_export]
macro_rules! push_db_pair_items {
    ($dbtx:ident, $prefix_type:expr, $key_type:ty, $value_type:ty, $map:ident, $key_literal:literal) => {
        let db_items =
            $crate::db::IDatabaseTransactionOpsCoreTyped::find_by_prefix($dbtx, &$prefix_type)
                .await
                .map(|(key, val)| {
                    (
                        $crate::encoding::Encodable::consensus_encode_to_hex(&key),
                        val,
                    )
                })
                // Fully qualified so call sites need not import `BTreeMap`.
                .collect::<::std::collections::BTreeMap<String, $value_type>>()
                .await;
        $map.insert($key_literal.to_string(), Box::new(db_items));
    };
}
/// Collects every key found under `$prefix_type` in `$dbtx` into a
/// `Vec<$key_type>` (values are discarded) and inserts it (boxed) into `$map`
/// under `$key_literal`.
///
/// Must be expanded inside an `async` context; the stream is consumed with
/// `map`/`collect`, so `futures::StreamExt` presumably needs to be in scope.
#[macro_export]
macro_rules! push_db_key_items {
($dbtx:ident, $prefix_type:expr, $key_type:ty, $map:ident, $key_literal:literal) => {
let db_items =
$crate::db::IDatabaseTransactionOpsCoreTyped::find_by_prefix($dbtx, &$prefix_type)
.await
.map(|(key, _)| key)
.collect::<Vec<$key_type>>()
.await;
$map.insert($key_literal.to_string(), Box::new(db_items));
};
}
/// A database migration step: receives a [`MigrationContext`] and returns a
/// boxed future that performs the migration.
///
/// In `apply_migrations`, the function stored under version `v` is what moves
/// the database from `v` to `v + 1`.
pub type CoreMigrationFn = for<'tx> fn(
MigrationContext<'tx>,
) -> Pin<
Box<maybe_add_send!(dyn futures::Future<Output = anyhow::Result<()>> + 'tx)>,
>;
/// Returns the database version the code expects: one past the highest
/// migration key, or `DatabaseVersion(0)` when there are no migrations.
///
/// # Panics
///
/// Panics if the migration keys are not consecutive versions (a gap indicates a
/// programming error).
pub fn get_current_database_version<F>(
    migrations: &BTreeMap<DatabaseVersion, F>,
) -> DatabaseVersion {
    let mut latest: Option<DatabaseVersion> = None;
    // BTreeMap keys iterate in ascending order, so each key must directly follow
    // the previous one.
    for &version in migrations.keys() {
        if let Some(previous) = latest {
            if previous.increment() != version {
                panic!("Database Migrations are not defined contiguously");
            }
        }
        latest = Some(version);
    }
    match latest {
        Some(version) => version.increment(),
        None => DatabaseVersion(0),
    }
}
/// Applies `migrations` to the global (non-module) database data.
///
/// Equivalent to `apply_migrations` with no module instance id and no
/// external-prefix filter.
pub async fn apply_migrations_server(
db: &Database,
kind: String,
migrations: BTreeMap<DatabaseVersion, CoreMigrationFn>,
) -> Result<(), anyhow::Error> {
apply_migrations(db, kind, migrations, None, None).await
}
pub async fn apply_migrations(
db: &Database,
kind: String,
migrations: BTreeMap<DatabaseVersion, CoreMigrationFn>,
module_instance_id: Option<ModuleInstanceId>,
external_prefixes_above: Option<u8>,
) -> Result<(), anyhow::Error> {
let mut dbtx = db.begin_transaction_nc().await;
let is_new_db = dbtx
.raw_find_by_prefix(&[])
.await?
.filter(|(key, _v)| {
std::future::ready(
external_prefixes_above.map_or(true, |external_prefixes_above| {
!key.is_empty() && key[0] < external_prefixes_above
}),
)
})
.next()
.await
.is_none();
let target_db_version = get_current_database_version(&migrations);
create_database_version(
db,
target_db_version,
module_instance_id,
kind.clone(),
is_new_db,
)
.await?;
let mut global_dbtx = db.begin_transaction().await;
let module_instance_id_key = module_instance_id_or_global(module_instance_id);
let disk_version = global_dbtx
.get_value(&DatabaseVersionKey(module_instance_id_key))
.await;
let db_version = if let Some(disk_version) = disk_version {
let mut current_db_version = disk_version;
if current_db_version > target_db_version {
return Err(anyhow::anyhow!(format!(
"On disk database version {current_db_version} for module {kind} was higher than the code database version {target_db_version}."
)));
}
while current_db_version < target_db_version {
if let Some(migration) = migrations.get(¤t_db_version) {
info!(target: LOG_DB, ?kind, ?current_db_version, ?target_db_version, "Migrating module...");
migration(MigrationContext {
dbtx: global_dbtx.to_ref_nc(),
module_instance_id,
})
.await?;
} else {
warn!(target: LOG_DB, ?current_db_version, "Missing server db migration");
}
current_db_version = current_db_version.increment();
global_dbtx
.insert_entry(
&DatabaseVersionKey(module_instance_id_key),
¤t_db_version,
)
.await;
}
current_db_version
} else {
target_db_version
};
global_dbtx.commit_tx_result().await?;
debug!(target: LOG_DB, ?kind, ?db_version, "DB Version");
Ok(())
}
/// Creates the `DatabaseVersionKey` record for `module_instance_id` (or the
/// global data) if it does not exist yet.
///
/// The initial value is taken from a legacy `DatabaseVersionKeyV0` entry
/// (which is removed). If there is none, it is `target_db_version` for a new
/// database and `DatabaseVersion(0)` otherwise. Does nothing if the record
/// already exists.
///
/// # Errors
///
/// Fails if committing the transaction fails.
pub async fn create_database_version(
    db: &Database,
    target_db_version: DatabaseVersion,
    module_instance_id: Option<ModuleInstanceId>,
    kind: String,
    is_new_db: bool,
) -> Result<(), anyhow::Error> {
    let key_module_instance_id = module_instance_id_or_global(module_instance_id);
    let mut global_dbtx = db.begin_transaction().await;
    if global_dbtx
        .get_value(&DatabaseVersionKey(key_module_instance_id))
        .await
        .is_none()
    {
        // The legacy V0 record lives inside the module's own key space, so for a
        // module it must be looked up through the module-prefixed view.
        let current_version_in_module = if let Some(module_instance_id) = module_instance_id {
            remove_current_db_version_if_exists(
                &mut global_dbtx
                    .to_ref_with_prefix_module_id(module_instance_id)
                    .0
                    .into_nc(),
                is_new_db,
                target_db_version,
            )
            .await
        } else {
            remove_current_db_version_if_exists(
                &mut global_dbtx.to_ref().into_nc(),
                is_new_db,
                target_db_version,
            )
            .await
        };
        info!(target: LOG_DB, ?kind, ?current_version_in_module, ?target_db_version, ?is_new_db, "Creating DatabaseVersionKey...");
        global_dbtx
            .insert_new_entry(
                &DatabaseVersionKey(key_module_instance_id),
                &current_version_in_module,
            )
            .await;
        global_dbtx.commit_tx_result().await?;
    }
    Ok(())
}
/// Removes the legacy `DatabaseVersionKeyV0` entry and returns its version.
///
/// When no legacy entry exists, returns `target_db_version` for a new database
/// and `DatabaseVersion(0)` for an existing one.
async fn remove_current_db_version_if_exists(
    version_dbtx: &mut DatabaseTransaction<'_>,
    is_new_db: bool,
    target_db_version: DatabaseVersion,
) -> DatabaseVersion {
    let fallback = if is_new_db {
        target_db_version
    } else {
        DatabaseVersion(0)
    };
    version_dbtx
        .remove_entry(&DatabaseVersionKeyV0)
        .await
        .unwrap_or(fallback)
}
/// Returns the given module instance id, or `MODULE_GLOBAL_PREFIX` (widened)
/// to denote the global data when there is none.
fn module_instance_id_or_global(module_instance_id: Option<ModuleInstanceId>) -> ModuleInstanceId {
    match module_instance_id {
        Some(id) => id,
        None => MODULE_GLOBAL_PREFIX.into(),
    }
}
/// State handed to a [`CoreMigrationFn`].
pub struct MigrationContext<'tx> {
    /// Un-prefixed transaction the migration runs in.
    dbtx: DatabaseTransaction<'tx>,
    /// Module instance being migrated, or `None` for global data.
    module_instance_id: Option<ModuleInstanceId>,
}
impl<'tx> MigrationContext<'tx> {
    /// Returns a transaction scoped to the data being migrated: prefixed with the
    /// module instance's key prefix when migrating a module, un-prefixed otherwise.
    pub fn dbtx(&mut self) -> DatabaseTransaction {
        match self.module_instance_id {
            Some(id) => self.dbtx.to_ref_with_prefix_module_id(id).0,
            None => self.dbtx.to_ref_nc(),
        }
    }

    /// The module instance being migrated, if any.
    pub fn module_instance_id(&self) -> Option<ModuleInstanceId> {
        self.module_instance_id
    }

    /// Direct access to the un-prefixed transaction (internal use only).
    #[doc(hidden)]
    pub fn __global_dbtx(&mut self) -> &mut DatabaseTransaction<'tx> {
        &mut self.dbtx
    }
}
#[allow(unused_imports)]
mod test_utils {
use std::collections::BTreeMap;
use std::time::Duration;
use fedimint_core::db::MigrationContext;
use futures::future::ready;
use futures::{Future, FutureExt, StreamExt};
use rand::Rng;
use tokio::join;
use super::{
apply_migrations, CoreMigrationFn, Database, DatabaseTransaction, DatabaseVersion,
DatabaseVersionKey, DatabaseVersionKeyV0,
};
use crate::core::ModuleKind;
use crate::db::mem_impl::MemDatabase;
use crate::db::{
IDatabaseTransactionOps, IDatabaseTransactionOpsCoreTyped, MODULE_GLOBAL_PREFIX,
};
use crate::encoding::{Decodable, Encodable};
use crate::module::registry::ModuleDecoderRegistry;
pub async fn future_returns_shortly<F: Future>(fut: F) -> Option<F::Output> {
crate::runtime::timeout(Duration::from_millis(10), fut)
.await
.ok()
}
/// Prefix bytes for the test records below.
#[repr(u8)]
#[derive(Clone)]
pub enum TestDbKeyPrefix {
Test = 0x42,
AltTest = 0x43,
/// 0x25 is ASCII `%`, used to check that prefix queries treat it literally.
PercentTestKey = 0x25,
}
/// Current-format test key (prefix `Test`, modification notifications enabled).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Encodable, Decodable)]
pub(super) struct TestKey(pub u64);
/// Lookup prefix matching all `TestKey` records.
#[derive(Debug, Encodable, Decodable)]
struct DbPrefixTestPrefix;
impl_db_record!(
key = TestKey,
value = TestVal,
db_prefix = TestDbKeyPrefix::Test,
notify_on_modify = true,
);
impl_db_lookup!(key = TestKey, query_prefix = DbPrefixTestPrefix);
/// Legacy two-field key sharing the `Test` prefix; converted to `TestKey` by
/// the migration test.
#[derive(Debug, Encodable, Decodable)]
struct TestKeyV0(u64, u64);
/// Lookup prefix matching all `TestKeyV0` records.
#[derive(Debug, Encodable, Decodable)]
struct DbPrefixTestPrefixV0;
impl_db_record!(
key = TestKeyV0,
value = TestVal,
db_prefix = TestDbKeyPrefix::Test,
);
impl_db_lookup!(key = TestKeyV0, query_prefix = DbPrefixTestPrefixV0);
/// Key under a different prefix, used to check prefix/range isolation.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Encodable, Decodable)]
struct AltTestKey(u64);
#[derive(Debug, Encodable, Decodable)]
struct AltDbPrefixTestPrefix;
impl_db_record!(
key = AltTestKey,
value = TestVal,
db_prefix = TestDbKeyPrefix::AltTest,
);
impl_db_lookup!(key = AltTestKey, query_prefix = AltDbPrefixTestPrefix);
/// Key whose prefix byte is `%`.
#[derive(Debug, Encodable, Decodable)]
struct PercentTestKey(u64);
#[derive(Debug, Encodable, Decodable)]
struct PercentPrefixTestPrefix;
impl_db_record!(
key = PercentTestKey,
value = TestVal,
db_prefix = TestDbKeyPrefix::PercentTestKey,
);
impl_db_lookup!(key = PercentTestKey, query_prefix = PercentPrefixTestPrefix);
/// Value type shared by all test records.
#[derive(Debug, Encodable, Decodable, Eq, PartialEq, PartialOrd, Ord)]
pub(super) struct TestVal(pub u64);
/// Module instance ids used by the module-prefix tests.
const TEST_MODULE_PREFIX: u16 = 1;
const ALT_MODULE_PREFIX: u16 = 2;
/// Checks that `insert_entry` returns `None` for new keys and the previous
/// value when overwriting, and that committed values are readable afterwards.
pub async fn verify_insert_elements(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert!(dbtx.insert_entry(&TestKey(1), &TestVal(2)).await.is_none());
assert!(dbtx.insert_entry(&TestKey(2), &TestVal(3)).await.is_none());
dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction().await;
assert_eq!(dbtx.get_value(&TestKey(1)).await, Some(TestVal(2)));
assert_eq!(dbtx.get_value(&TestKey(2)).await, Some(TestVal(3)));
dbtx.commit_tx().await;
// Overwrite: the old value must be returned.
let mut dbtx = db.begin_transaction().await;
assert_eq!(
dbtx.insert_entry(&TestKey(1), &TestVal(4)).await,
Some(TestVal(2))
);
assert_eq!(
dbtx.insert_entry(&TestKey(2), &TestVal(5)).await,
Some(TestVal(3))
);
dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction().await;
assert_eq!(dbtx.get_value(&TestKey(1)).await, Some(TestVal(4)));
assert_eq!(dbtx.get_value(&TestKey(2)).await, Some(TestVal(5)));
dbtx.commit_tx().await;
}
/// Removing a key that does not exist returns `None`.
pub async fn verify_remove_nonexisting(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert_eq!(dbtx.get_value(&TestKey(1)).await, None);
let removed = dbtx.remove_entry(&TestKey(1)).await;
assert!(removed.is_none());
dbtx.commit_tx().await;
}
/// Removing an existing key returns its value and makes it unreadable within
/// the same transaction.
pub async fn verify_remove_existing(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert!(dbtx.insert_entry(&TestKey(1), &TestVal(2)).await.is_none());
assert_eq!(dbtx.get_value(&TestKey(1)).await, Some(TestVal(2)));
let removed = dbtx.remove_entry(&TestKey(1)).await;
assert_eq!(removed, Some(TestVal(2)));
assert_eq!(dbtx.get_value(&TestKey(1)).await, None);
dbtx.commit_tx().await;
}
/// A transaction sees its own uncommitted writes.
pub async fn verify_read_own_writes(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert!(dbtx.insert_entry(&TestKey(1), &TestVal(2)).await.is_none());
assert_eq!(dbtx.get_value(&TestKey(1)).await, Some(TestVal(2)));
dbtx.commit_tx().await;
}
/// A concurrent transaction must not see another transaction's uncommitted write.
pub async fn verify_prevent_dirty_reads(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert!(dbtx.insert_entry(&TestKey(1), &TestVal(2)).await.is_none());
let mut dbtx2 = db.begin_transaction().await;
assert_eq!(dbtx2.get_value(&TestKey(1)).await, None);
dbtx.commit_tx().await;
}
pub async fn verify_find_by_range(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(55), &TestVal(9999)).await;
dbtx.insert_entry(&TestKey(54), &TestVal(8888)).await;
dbtx.insert_entry(&TestKey(56), &TestVal(7777)).await;
dbtx.insert_entry(&AltTestKey(55), &TestVal(7777)).await;
dbtx.insert_entry(&AltTestKey(54), &TestVal(6666)).await;
{
let mut module_dbtx = dbtx.to_ref_with_prefix_module_id(2).0;
module_dbtx
.insert_entry(&TestKey(300), &TestVal(3000))
.await;
}
dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction_nc().await;
let returned_keys = dbtx
.find_by_range(TestKey(55)..TestKey(56))
.await
.collect::<Vec<_>>()
.await;
let expected = vec![(TestKey(55), TestVal(9999))];
assert_eq!(returned_keys, expected);
let returned_keys = dbtx
.find_by_range(TestKey(54)..TestKey(56))
.await
.collect::<Vec<_>>()
.await;
let expected = vec![(TestKey(54), TestVal(8888)), (TestKey(55), TestVal(9999))];
assert_eq!(returned_keys, expected);
let returned_keys = dbtx
.find_by_range(TestKey(54)..TestKey(57))
.await
.collect::<Vec<_>>()
.await;
let expected = vec![
(TestKey(54), TestVal(8888)),
(TestKey(55), TestVal(9999)),
(TestKey(56), TestVal(7777)),
];
assert_eq!(returned_keys, expected);
let mut module_dbtx = dbtx.with_prefix_module_id(2).0;
let test_range = module_dbtx
.find_by_range(TestKey(300)..TestKey(301))
.await
.collect::<Vec<_>>()
.await;
assert!(test_range.len() == 1);
}
/// `find_by_prefix` returns only entries of the requested record type in
/// ascending key order; `find_by_prefix_sorted_descending` returns the same
/// entries reversed.
pub async fn verify_find_by_prefix(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(55), &TestVal(9999)).await;
dbtx.insert_entry(&TestKey(54), &TestVal(8888)).await;
dbtx.insert_entry(&AltTestKey(55), &TestVal(7777)).await;
dbtx.insert_entry(&AltTestKey(54), &TestVal(6666)).await;
dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction().await;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.collect::<Vec<_>>()
.await;
let expected = vec![(TestKey(54), TestVal(8888)), (TestKey(55), TestVal(9999))];
assert_eq!(returned_keys, expected);
let reversed = dbtx
.find_by_prefix_sorted_descending(&DbPrefixTestPrefix)
.await
.collect::<Vec<_>>()
.await;
let mut reversed_expected = expected;
reversed_expected.reverse();
assert_eq!(reversed, reversed_expected);
let returned_keys = dbtx
.find_by_prefix(&AltDbPrefixTestPrefix)
.await
.collect::<Vec<_>>()
.await;
let expected = vec![
(AltTestKey(54), TestVal(6666)),
(AltTestKey(55), TestVal(7777)),
];
assert_eq!(returned_keys, expected);
let reversed = dbtx
.find_by_prefix_sorted_descending(&AltDbPrefixTestPrefix)
.await
.collect::<Vec<_>>()
.await;
let mut reversed_expected = expected;
reversed_expected.reverse();
assert_eq!(reversed, reversed_expected);
}
/// A committed write is visible to transactions started afterwards.
pub async fn verify_commit(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert!(dbtx.insert_entry(&TestKey(1), &TestVal(2)).await.is_none());
dbtx.commit_tx().await;
let mut dbtx2 = db.begin_transaction().await;
assert_eq!(dbtx2.get_value(&TestKey(1)).await, Some(TestVal(2)));
}
pub async fn verify_rollback_to_savepoint(db: Database) {
let mut dbtx_rollback = db.begin_transaction().await;
dbtx_rollback
.insert_entry(&TestKey(20), &TestVal(2000))
.await;
dbtx_rollback
.set_tx_savepoint()
.await
.expect("Error setting transaction savepoint");
dbtx_rollback
.insert_entry(&TestKey(21), &TestVal(2001))
.await;
assert_eq!(
dbtx_rollback.get_value(&TestKey(20)).await,
Some(TestVal(2000))
);
assert_eq!(
dbtx_rollback.get_value(&TestKey(21)).await,
Some(TestVal(2001))
);
dbtx_rollback
.rollback_tx_to_savepoint()
.await
.expect("Error setting transaction savepoint");
assert_eq!(
dbtx_rollback.get_value(&TestKey(20)).await,
Some(TestVal(2000))
);
assert_eq!(dbtx_rollback.get_value(&TestKey(21)).await, None);
dbtx_rollback.commit_tx().await;
}
/// A transaction that already read a key must keep seeing the same (absent)
/// value even after another transaction commits a write to it, both for point
/// reads and for prefix scans.
pub async fn verify_prevent_nonrepeatable_reads(db: Database) {
let mut dbtx = db.begin_transaction().await;
assert_eq!(dbtx.get_value(&TestKey(100)).await, None);
let mut dbtx2 = db.begin_transaction().await;
dbtx2.insert_entry(&TestKey(100), &TestVal(101)).await;
assert_eq!(dbtx.get_value(&TestKey(100)).await, None);
dbtx2.commit_tx().await;
// Still invisible to `dbtx` after dbtx2 committed.
assert_eq!(dbtx.get_value(&TestKey(100)).await, None);
let expected_keys = 0;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
if key == TestKey(100) {
assert!(value.eq(&TestVal(101)));
}
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
}
/// Races a reader against a writer that commits two keys atomically, 1000
/// times with random interleavings.
///
/// The reader must observe either both keys or neither (snapshot isolation).
/// Depending on `i % 5`, it reads the second key with `get_value`,
/// `remove_entry`, `insert_entry`, or an ascending or descending prefix scan.
pub async fn verify_snapshot_isolation(db: Database) {
/// Yields to the scheduler either 0 or 10 times, chosen at random, to vary
/// how the two futures interleave.
async fn random_yield() {
let times = if rand::thread_rng().gen_bool(0.5) {
0
} else {
10
};
for _ in 0..times {
tokio::task::yield_now().await;
}
}
for i in 0..1000 {
let base_key = i * 2;
let tx_accepted_key = base_key;
let spent_input_key = base_key + 1;
join!(
// Reader: `a` and `s` must both be present or both absent.
async {
random_yield().await;
let mut dbtx = db.begin_transaction().await;
random_yield().await;
let a = dbtx.get_value(&TestKey(tx_accepted_key)).await;
random_yield().await;
let s = match i % 5 {
0 => dbtx.get_value(&TestKey(spent_input_key)).await,
1 => dbtx.remove_entry(&TestKey(spent_input_key)).await,
2 => {
dbtx.insert_entry(&TestKey(spent_input_key), &TestVal(200))
.await
}
3 => {
dbtx.find_by_prefix(&DbPrefixTestPrefix)
.await
.filter(|(k, _v)| ready(k == &TestKey(spent_input_key)))
.map(|(_k, v)| v)
.next()
.await
}
4 => {
dbtx.find_by_prefix_sorted_descending(&DbPrefixTestPrefix)
.await
.filter(|(k, _v)| ready(k == &TestKey(spent_input_key)))
.map(|(_k, v)| v)
.next()
.await
}
_ => {
panic!("woot?");
}
};
match (a, s) {
(None, None) | (Some(_), Some(_)) => {}
(None, Some(_)) => panic!("none some?! {i}"),
(Some(_), None) => panic!("some none?! {i}"),
}
},
// Writer: inserts both keys and commits them together.
async {
random_yield().await;
let mut dbtx = db.begin_transaction().await;
random_yield().await;
assert_eq!(dbtx.get_value(&TestKey(tx_accepted_key)).await, None);
random_yield().await;
assert_eq!(
dbtx.insert_entry(&TestKey(spent_input_key), &TestVal(100))
.await,
None
);
random_yield().await;
assert_eq!(
dbtx.insert_entry(&TestKey(tx_accepted_key), &TestVal(100))
.await,
None
);
random_yield().await;
dbtx.commit_tx().await;
}
);
}
}
/// A prefix scan repeated inside one transaction must return the same number
/// of entries even after another transaction commits a new matching key
/// (no phantom reads).
pub async fn verify_phantom_entry(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(100), &TestVal(101)).await;
dbtx.insert_entry(&TestKey(101), &TestVal(102)).await;
dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction().await;
let expected_keys = 2;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
match key {
TestKey(100) => {
assert!(value.eq(&TestVal(101)));
}
TestKey(101) => {
assert!(value.eq(&TestVal(102)));
}
_ => {}
};
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
// Concurrently commit a third matching key; `dbtx` must not see it.
let mut dbtx2 = db.begin_transaction().await;
dbtx2.insert_entry(&TestKey(102), &TestVal(103)).await;
dbtx2.commit_tx().await;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
match key {
TestKey(100) => {
assert!(value.eq(&TestVal(101)));
}
TestKey(101) => {
assert!(value.eq(&TestVal(102)));
}
_ => {}
};
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
}
/// Two concurrent transactions write the same key. The first commit succeeds
/// and the second must fail with an error (write-write conflict).
pub async fn expect_write_conflict(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(100), &TestVal(101)).await;
dbtx.commit_tx().await;
let mut dbtx2 = db.begin_transaction().await;
let mut dbtx3 = db.begin_transaction().await;
dbtx2.insert_entry(&TestKey(100), &TestVal(102)).await;
dbtx3.insert_entry(&TestKey(100), &TestVal(103)).await;
dbtx2.commit_tx().await;
dbtx3.commit_tx_result().await.expect_err("Expecting an error to be returned because this transaction is in a write-write conflict with dbtx");
}
pub async fn verify_string_prefix(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&PercentTestKey(100), &TestVal(101)).await;
assert_eq!(
dbtx.get_value(&PercentTestKey(100)).await,
Some(TestVal(101))
);
dbtx.insert_entry(&PercentTestKey(101), &TestVal(100)).await;
dbtx.insert_entry(&PercentTestKey(101), &TestVal(100)).await;
dbtx.insert_entry(&PercentTestKey(101), &TestVal(100)).await;
dbtx.insert_entry(&TestKey(101), &TestVal(100)).await;
let expected_keys = 4;
let returned_keys = dbtx
.find_by_prefix(&PercentPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
if matches!(key, PercentTestKey(101)) {
assert!(value.eq(&TestVal(100)));
}
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
}
/// After a committed `remove_by_prefix`, no entries remain under that prefix.
pub async fn verify_remove_by_prefix(db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(100), &TestVal(101)).await;
dbtx.insert_entry(&TestKey(101), &TestVal(102)).await;
dbtx.commit_tx().await;
let mut remove_dbtx = db.begin_transaction().await;
remove_dbtx.remove_by_prefix(&DbPrefixTestPrefix).await;
remove_dbtx.commit_tx().await;
let mut dbtx = db.begin_transaction().await;
let expected_keys = 0;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
match key {
TestKey(100) => {
assert!(value.eq(&TestVal(101)));
}
TestKey(101) => {
assert!(value.eq(&TestVal(102)));
}
_ => {}
};
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
}
/// `db` and `module_db` must be isolated key spaces: identical keys written
/// through one are invisible to, and independent from, the other.
///
/// NOTE(review): assumes the caller passes a module-scoped view of `db` as
/// `module_db` — confirm at call sites.
pub async fn verify_module_db(db: Database, module_db: Database) {
let mut dbtx = db.begin_transaction().await;
dbtx.insert_entry(&TestKey(100), &TestVal(101)).await;
dbtx.insert_entry(&TestKey(101), &TestVal(102)).await;
dbtx.commit_tx().await;
let mut module_dbtx = module_db.begin_transaction().await;
assert_eq!(module_dbtx.get_value(&TestKey(100)).await, None);
assert_eq!(module_dbtx.get_value(&TestKey(101)).await, None);
let mut dbtx = db.begin_transaction().await;
assert_eq!(dbtx.get_value(&TestKey(100)).await, Some(TestVal(101)));
assert_eq!(dbtx.get_value(&TestKey(101)).await, Some(TestVal(102)));
let mut module_dbtx = module_db.begin_transaction().await;
module_dbtx.insert_entry(&TestKey(100), &TestVal(103)).await;
module_dbtx.insert_entry(&TestKey(101), &TestVal(104)).await;
module_dbtx.commit_tx().await;
// The outer database still sees only its own two entries.
let expected_keys = 2;
let mut dbtx = db.begin_transaction().await;
let returned_keys = dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
match key {
TestKey(100) => {
assert!(value.eq(&TestVal(101)));
}
TestKey(101) => {
assert!(value.eq(&TestVal(102)));
}
_ => {}
};
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
// Removing from the outer database does not affect the module's entry.
let removed = dbtx.remove_entry(&TestKey(100)).await;
assert_eq!(removed, Some(TestVal(101)));
assert_eq!(dbtx.get_value(&TestKey(100)).await, None);
let mut module_dbtx = module_db.begin_transaction().await;
assert_eq!(
module_dbtx.get_value(&TestKey(100)).await,
Some(TestVal(103))
);
}
/// Two module instances writing identical keys through
/// `to_ref_with_prefix_module_id` must not see each other's data, and
/// module-prefixed keys must not be visible un-prefixed.
pub async fn verify_module_prefix(db: Database) {
let mut test_dbtx = db.begin_transaction().await;
{
let mut test_module_dbtx = test_dbtx.to_ref_with_prefix_module_id(TEST_MODULE_PREFIX).0;
test_module_dbtx
.insert_entry(&TestKey(100), &TestVal(101))
.await;
test_module_dbtx
.insert_entry(&TestKey(101), &TestVal(102))
.await;
}
test_dbtx.commit_tx().await;
let mut alt_dbtx = db.begin_transaction().await;
{
let mut alt_module_dbtx = alt_dbtx.to_ref_with_prefix_module_id(ALT_MODULE_PREFIX).0;
alt_module_dbtx
.insert_entry(&TestKey(100), &TestVal(103))
.await;
alt_module_dbtx
.insert_entry(&TestKey(101), &TestVal(104))
.await;
}
alt_dbtx.commit_tx().await;
// The test module sees only its own values, not the alt module's.
let mut test_dbtx = db.begin_transaction().await;
let mut test_module_dbtx = test_dbtx.to_ref_with_prefix_module_id(TEST_MODULE_PREFIX).0;
assert_eq!(
test_module_dbtx.get_value(&TestKey(100)).await,
Some(TestVal(101))
);
assert_eq!(
test_module_dbtx.get_value(&TestKey(101)).await,
Some(TestVal(102))
);
let expected_keys = 2;
let returned_keys = test_module_dbtx
.find_by_prefix(&DbPrefixTestPrefix)
.await
.fold(0, |returned_keys, (key, value)| async move {
match key {
TestKey(100) => {
assert!(value.eq(&TestVal(101)));
}
TestKey(101) => {
assert!(value.eq(&TestVal(102)));
}
_ => {}
};
returned_keys + 1
})
.await;
assert_eq!(returned_keys, expected_keys);
let removed = test_module_dbtx.remove_entry(&TestKey(100)).await;
assert_eq!(removed, Some(TestVal(101)));
assert_eq!(test_module_dbtx.get_value(&TestKey(100)).await, None);
// Module-prefixed keys are not visible through an un-prefixed transaction.
let mut test_dbtx = db.begin_transaction().await;
assert_eq!(test_dbtx.get_value(&TestKey(101)).await, None);
test_dbtx.commit_tx().await;
}
/// Seeds an in-memory database with version-0 test keys, runs the
/// version-0 migration through `apply_migrations`, and checks the result:
/// a database version record exists and every migrated `TestKey` equals its
/// value plus one.
#[cfg(test)]
#[tokio::test]
pub async fn verify_test_migration() {
    let db = Database::new(MemDatabase::new(), ModuleDecoderRegistry::default());
    let expected_test_keys_size: usize = 100;

    // Write `TestKeyV0(i, i + 1) -> TestVal(i)` plus the legacy version marker.
    let mut seed_tx = db.begin_transaction().await;
    for idx in 0..expected_test_keys_size {
        let first = idx as u64;
        let second = (idx + 1) as u64;
        seed_tx
            .insert_new_entry(&TestKeyV0(first, second), &TestVal(first))
            .await;
    }
    seed_tx
        .insert_new_entry(&DatabaseVersionKeyV0, &DatabaseVersion(0))
        .await;
    seed_tx.commit_tx().await;

    // Register the single version-0 migration and apply it.
    let mut migrations: BTreeMap<DatabaseVersion, CoreMigrationFn> = BTreeMap::new();
    migrations.insert(DatabaseVersion(0), |ctx| migrate_test_db_version_0(ctx).boxed());
    apply_migrations(&db, "TestModule".to_string(), migrations, None, None)
        .await
        .expect("Error applying migrations for TestModule");

    let mut check_tx = db.begin_transaction().await;
    let version = check_tx
        .get_value(&DatabaseVersionKey(MODULE_GLOBAL_PREFIX.into()))
        .await;
    assert!(version.is_some());

    let migrated = check_tx
        .find_by_prefix(&DbPrefixTestPrefix)
        .await
        .collect::<Vec<_>>()
        .await;
    assert_eq!(migrated.len(), expected_test_keys_size);

    // The migration keeps the second V0 component as the new key, so each
    // key must be exactly one more than its value.
    for (new_key, new_val) in migrated {
        assert_eq!(new_key.0, new_val.0 + 1);
    }
}
/// Test migration from the version-0 test schema: every `DbPrefixTestPrefixV0`
/// entry is rewritten as `TestKey(<second V0 key component>)` with its value
/// unchanged.
///
/// All legacy entries are read first, then the old prefix is cleared, and
/// only then are the new records inserted.
// NOTE(review): presumably only called from `#[cfg(test)]` code such as
// `verify_test_migration`, hence the dead-code allowance for non-test builds.
#[allow(dead_code)]
async fn migrate_test_db_version_0(mut ctx: MigrationContext<'_>) -> Result<(), anyhow::Error> {
    let mut dbtx = ctx.dbtx();

    // Snapshot the legacy records before deleting them.
    let legacy_entries = dbtx
        .find_by_prefix(&DbPrefixTestPrefixV0)
        .await
        .collect::<Vec<_>>()
        .await;
    dbtx.remove_by_prefix(&DbPrefixTestPrefixV0).await;

    for (legacy_key, value) in legacy_entries {
        dbtx.insert_new_entry(&TestKey(legacy_key.1), &value).await;
    }

    Ok(())
}
/// Checks that `Database::autocommit` gives up when every commit fails, and
/// that the error is `AutocommitError::CommitFailed` reporting 5 attempts when
/// `Some(5)` is passed as the attempt limit.
#[cfg(test)]
#[tokio::test]
async fn test_autocommit() {
    use std::marker::PhantomData;
    use std::ops::Range;
    use std::path::Path;

    use anyhow::anyhow;
    use async_trait::async_trait;

    use crate::db::{
        AutocommitError, IDatabaseTransactionOps, IDatabaseTransactionOpsCore, IRawDatabase,
        IRawDatabaseTransaction,
    };
    use crate::ModuleDecoderRegistry;

    /// Raw database stub whose transactions can be opened but never committed.
    #[derive(Debug)]
    struct FakeDatabase;

    #[async_trait]
    impl IRawDatabase for FakeDatabase {
        type Transaction<'a> = FakeTransaction<'a>;
        async fn begin_transaction(&self) -> FakeTransaction {
            FakeTransaction(PhantomData)
        }

        fn checkpoint(&self, _backup_path: &Path) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Transaction stub. Every data operation is `unimplemented!()` because the
    /// closure passed to `autocommit` below never touches the transaction.
    /// `commit_tx` always returns an error.
    #[derive(Debug)]
    struct FakeTransaction<'a>(PhantomData<&'a ()>);

    #[async_trait]
    impl<'a> IDatabaseTransactionOpsCore for FakeTransaction<'a> {
        async fn raw_insert_bytes(
            &mut self,
            _key: &[u8],
            _value: &[u8],
        ) -> anyhow::Result<Option<Vec<u8>>> {
            unimplemented!()
        }

        async fn raw_get_bytes(&mut self, _key: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
            unimplemented!()
        }

        async fn raw_remove_entry(&mut self, _key: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
            unimplemented!()
        }

        async fn raw_find_by_range(
            &mut self,
            _key_range: Range<&[u8]>,
        ) -> anyhow::Result<crate::db::PrefixStream<'_>> {
            unimplemented!()
        }

        async fn raw_find_by_prefix(
            &mut self,
            _key_prefix: &[u8],
        ) -> anyhow::Result<crate::db::PrefixStream<'_>> {
            unimplemented!()
        }

        async fn raw_remove_by_prefix(&mut self, _key_prefix: &[u8]) -> anyhow::Result<()> {
            unimplemented!()
        }

        async fn raw_find_by_prefix_sorted_descending(
            &mut self,
            _key_prefix: &[u8],
        ) -> anyhow::Result<crate::db::PrefixStream<'_>> {
            unimplemented!()
        }
    }

    #[async_trait]
    impl<'a> IDatabaseTransactionOps for FakeTransaction<'a> {
        async fn rollback_tx_to_savepoint(&mut self) -> anyhow::Result<()> {
            unimplemented!()
        }

        async fn set_tx_savepoint(&mut self) -> anyhow::Result<()> {
            unimplemented!()
        }
    }

    #[async_trait]
    impl<'a> IRawDatabaseTransaction for FakeTransaction<'a> {
        async fn commit_tx(self) -> anyhow::Result<()> {
            // Every commit attempt fails, forcing autocommit to retry.
            Err(anyhow!("Can't commit!"))
        }
    }

    let db = Database::new(FakeDatabase, ModuleDecoderRegistry::default());

    // The closure always succeeds, so any failure has to come from
    // `commit_tx`. Per the assertion below, `Some(5)` should yield exactly
    // 5 failed commit attempts.
    let err = db
        .autocommit::<_, _, ()>(|_dbtx, _| Box::pin(async { Ok(()) }), Some(5))
        .await
        .unwrap_err();

    match err {
        AutocommitError::CommitFailed {
            attempts: failed_attempts,
            ..
        } => {
            assert_eq!(failed_attempts, 5);
        }
        AutocommitError::ClosureError { .. } => panic!("Closure did not return error"),
    }
}
}
/// Streams every record under `key_prefix`, using the transaction's
/// `raw_find_by_prefix_sorted_descending`.
///
/// Each raw `(key, value)` byte pair is decoded with `decoders`: the key
/// first, then the value. The returned stream borrows `tx` for `'r`.
///
/// # Panics
/// Panics if the underlying prefix search returns an error. Decoding goes
/// through `decode_key_expect` / `decode_value_expect`, which presumably panic
/// on malformed bytes (defined outside this chunk).
pub async fn find_by_prefix_sorted_descending<'r, 'inner, KP>(
    tx: &'r mut (dyn IDatabaseTransaction + 'inner),
    decoders: ModuleDecoderRegistry,
    key_prefix: &KP,
) -> impl Stream<Item = (KP::Record, <<KP as DatabaseLookup>::Record as DatabaseRecord>::Value)> + 'r
where
    'inner: 'r,
    KP: DatabaseLookup,
    KP::Record: DatabaseKey,
{
    debug!("find by prefix sorted descending");
    let encoded_prefix = key_prefix.to_bytes();
    let raw_entries = tx
        .raw_find_by_prefix_sorted_descending(&encoded_prefix)
        .await
        .expect("Error doing prefix search in database");

    raw_entries.map(move |(raw_key, raw_value)| {
        // Tuple fields evaluate left to right: the key is decoded before the value.
        (
            decode_key_expect(&raw_key, &decoders),
            decode_value_expect(&raw_value, &decoders, &raw_key),
        )
    })
}
// Unit tests for key-existence notifications (`wait_key_exists`) and for
// `global_dbtx` on prefixed and unprefixed transactions.
#[cfg(test)]
mod tests {
use tokio::sync::oneshot;
use super::mem_impl::MemDatabase;
use super::*;
use crate::runtime::spawn;
// Spawns a task that awaits `db.wait_key_exists(&key)` and returns its handle.
// The task creates the wait future *before* signalling over the oneshot
// channel, and this function returns only after that signal, so the caller
// knows the waiter has been set up.
// NOTE(review): this assumes `wait_key_exists` registers its subscription when
// called rather than on first poll; confirm against its implementation.
async fn waiter(db: &Database, key: TestKey) -> tokio::task::JoinHandle<TestVal> {
let db = db.clone();
let (tx, rx) = oneshot::channel::<()>();
let join_handle = spawn("wait key exists", async move {
let sub = db.wait_key_exists(&key);
tx.send(()).unwrap();
sub.await
});
rx.await.unwrap();
join_handle
}
// Waiter set up before the writing transaction begins; the committed value
// must be delivered.
// `future_returns_shortly` is defined outside this chunk. Presumably it yields
// `None` when the future does not finish within a short timeout.
#[tokio::test]
async fn test_wait_key_before_transaction() {
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let key_task = waiter(&db, TestKey(1)).await;
let mut tx = db.begin_transaction().await;
tx.insert_new_entry(&key, &val).await;
tx.commit_tx().await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Waiter set up after the transaction begins but before the insert.
#[tokio::test]
async fn test_wait_key_before_insert() {
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let mut tx = db.begin_transaction().await;
let key_task = waiter(&db, TestKey(1)).await;
tx.insert_new_entry(&key, &val).await;
tx.commit_tx().await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Waiter set up after the (uncommitted) insert but before the commit.
#[tokio::test]
async fn test_wait_key_after_insert() {
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let mut tx = db.begin_transaction().await;
tx.insert_new_entry(&key, &val).await;
let key_task = waiter(&db, TestKey(1)).await;
tx.commit_tx().await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Waiter set up after the key is already committed; it must still resolve
// with the existing value.
#[tokio::test]
async fn test_wait_key_after_commit() {
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let mut tx = db.begin_transaction().await;
tx.insert_new_entry(&key, &val).await;
tx.commit_tx().await;
let key_task = waiter(&db, TestKey(1)).await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Wait and write both go through the same module-prefixed `Database` view.
#[tokio::test]
async fn test_wait_key_isolated_db() {
let module_instance_id = 10;
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let db = db.with_prefix_module_id(module_instance_id).0;
let key_task = waiter(&db, TestKey(1)).await;
let mut tx = db.begin_transaction().await;
tx.insert_new_entry(&key, &val).await;
tx.commit_tx().await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Wait goes through a module-prefixed `Database` view; the write goes through a
// module-prefixed view of a root transaction. Both must resolve to the same
// keyspace.
#[tokio::test]
async fn test_wait_key_isolated_tx() {
let module_instance_id = 10;
let key = TestKey(1);
let val = TestVal(2);
let db = MemDatabase::new().into_database();
let key_task = waiter(&db.with_prefix_module_id(module_instance_id).0, TestKey(1)).await;
let mut tx = db.begin_transaction().await;
let mut tx_mod = tx.to_ref_with_prefix_module_id(module_instance_id).0;
tx_mod.insert_new_entry(&key, &val).await;
// Presumably releases the module view's borrow of `tx` before `tx` is committed.
drop(tx_mod);
tx.commit_tx().await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
Some(TestVal(2)),
"should notify"
);
}
// Nothing is ever written, so the waiter must not resolve.
#[tokio::test]
async fn test_wait_key_no_transaction() {
let db = MemDatabase::new().into_database();
let key_task = waiter(&db, TestKey(1)).await;
assert_eq!(
future_returns_shortly(async { key_task.await.unwrap() }).await,
None,
"should not notify"
);
}
// A `global_dbtx` obtained with a module's access token writes to the
// unprefixed keyspace. This holds even when an extra `with_prefix` is layered
// on top of the module-prefixed database (second case).
#[tokio::test]
async fn test_prefix_global_dbtx() {
let module_instance_id = 10;
let db = MemDatabase::new().into_database();
{
let (db, access_token) = db.with_prefix_module_id(module_instance_id);
let mut tx = db.begin_transaction().await;
let mut tx = tx.global_dbtx(access_token);
tx.insert_new_entry(&TestKey(1), &TestVal(1)).await;
tx.commit_tx().await;
}
assert_eq!(
db.begin_transaction_nc().await.get_value(&TestKey(1)).await,
Some(TestVal(1))
);
{
let (db, access_token) = db.with_prefix_module_id(module_instance_id);
let db = db.with_prefix(vec![3, 4]);
let mut tx = db.begin_transaction().await;
let mut tx = tx.global_dbtx(access_token);
tx.insert_new_entry(&TestKey(2), &TestVal(2)).await;
tx.commit_tx().await;
}
assert_eq!(
db.begin_transaction_nc().await.get_value(&TestKey(2)).await,
Some(TestVal(2))
);
}
// `global_dbtx` on an unprefixed root transaction must panic.
#[tokio::test]
#[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")]
async fn test_prefix_global_dbtx_panics_on_global_db() {
let db = MemDatabase::new().into_database();
let mut tx = db.begin_transaction().await;
let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&[1]));
}
// `global_dbtx` on a plain (non-module) prefix must panic, even when the token
// is derived from that same prefix.
#[tokio::test]
#[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")]
async fn test_prefix_global_dbtx_panics_on_non_module_prefix() {
let db = MemDatabase::new().into_database();
let prefix = vec![3, 4];
let db = db.with_prefix(prefix.clone());
let mut tx = db.begin_transaction().await;
let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&prefix));
}
// `global_dbtx` with a token that does not match the database prefix must
// panic (the expected message is the same as in the tests above).
#[tokio::test]
#[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")]
async fn test_prefix_global_dbtx_panics_on_wrong_access_token() {
let db = MemDatabase::new().into_database();
let prefix = vec![3, 4];
let db = db.with_prefix(prefix.clone());
let mut tx = db.begin_transaction().await;
let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&[1]));
}
}