1pub use alloy_eips::BlockId;
4use alloy_provider::{
5 network::{primitives::HeaderResponse, BlockResponse},
6 Network, Provider,
7};
8use alloy_transport::TransportError;
9use core::error::Error;
10use database_interface::{async_db::DatabaseAsyncRef, DBErrorMarker};
11use primitives::{Address, StorageKey, StorageValue, B256};
12use state::{AccountInfo, Bytecode};
13use std::fmt::Display;
14
/// Errors returned by [`AlloyDB`] when querying state through the RPC provider.
#[derive(Debug)]
pub enum AlloyDBError {
    /// An error reported by the provider's transport layer while performing a request.
    Transport(TransportError),
    /// The provider returned no block for the requested block number.
    BlockNotFound(u64),
}
27
// Marker impl from `database_interface`.
// NOTE(review): presumably required so `AlloyDBError` can serve as the
// `DatabaseAsyncRef::Error` type below — confirm against the trait's bounds.
impl DBErrorMarker for AlloyDBError {}
29
30impl Display for AlloyDBError {
31 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32 match self {
33 Self::Transport(err) => write!(f, "Transport error: {err}"),
34 Self::BlockNotFound(number) => write!(f, "Block not found: {number}"),
35 }
36 }
37}
38
39impl Error for AlloyDBError {
40 fn source(&self) -> Option<&(dyn Error + 'static)> {
41 match self {
42 Self::Transport(err) => Some(err),
43 Self::BlockNotFound(_) => None,
44 }
45 }
46}
47
48impl From<TransportError> for AlloyDBError {
49 fn from(e: TransportError) -> Self {
50 Self::Transport(e)
51 }
52}
53
/// A [`DatabaseAsyncRef`] that reads account state, storage and block hashes
/// through an alloy RPC [`Provider`].
#[derive(Debug)]
pub struct AlloyDB<N: Network, P: Provider<N>> {
    /// RPC provider used for every state query.
    provider: P,
    /// Block that account and storage queries are pinned to (passed to `.block_id(..)`).
    /// Despite the field name, this is a [`BlockId`] rather than a bare `u64`.
    block_number: BlockId,
    /// Records the network type `N` without storing a value of it.
    // NOTE(review): the `fn() -> N` form keeps `AlloyDB`'s Send/Sync independent of `N`;
    // presumably that is the intent — confirm.
    _marker: core::marker::PhantomData<fn() -> N>,
}
65
66impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
67 pub fn new(provider: P, block_number: BlockId) -> Self {
69 Self {
70 provider,
71 block_number,
72 _marker: core::marker::PhantomData,
73 }
74 }
75
76 pub fn set_block_number(&mut self, block_number: BlockId) {
78 self.block_number = block_number;
79 }
80}
81
82impl<N: Network, P: Provider<N>> DatabaseAsyncRef for AlloyDB<N, P> {
83 type Error = AlloyDBError;
84
85 async fn basic_async_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
86 let nonce = self
87 .provider
88 .get_transaction_count(address)
89 .block_id(self.block_number);
90 let balance = self
91 .provider
92 .get_balance(address)
93 .block_id(self.block_number);
94 let code = self
95 .provider
96 .get_code_at(address)
97 .block_id(self.block_number);
98
99 let (nonce, balance, code) = tokio::join!(nonce, balance, code,);
100
101 let balance = balance?;
102 let code = Bytecode::new_raw(code?.0.into());
103 let code_hash = code.hash_slow();
104 let nonce = nonce?;
105
106 Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
107 }
108
109 async fn block_hash_async_ref(&self, number: u64) -> Result<B256, Self::Error> {
110 let block = self
111 .provider
112 .get_block_by_number(number.into())
114 .await?;
115
116 match block {
117 Some(block) => Ok(B256::new(*block.header().hash())),
118 None => Err(AlloyDBError::BlockNotFound(number)),
119 }
120 }
121
122 async fn code_by_hash_async_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
123 panic!("This should not be called, as the code is already loaded");
124 }
126
127 async fn storage_async_ref(
128 &self,
129 address: Address,
130 index: StorageKey,
131 ) -> Result<StorageValue, Self::Error> {
132 Ok(self
133 .provider
134 .get_storage_at(address, index)
135 .block_id(self.block_number)
136 .await?)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_provider::ProviderBuilder;
    use database_interface::{DatabaseRef, WrapDatabaseAsync};

    /// Checks that the account at a fixed mainnet address exists at block 16148323.
    ///
    /// Ignored by default because it needs network access to a public RPC endpoint.
    // NOTE(review): the endpoint URL embeds an Infura project key in source; consider
    // reading it from an environment variable instead.
    #[tokio::test]
    #[ignore = "flaky RPC"]
    async fn can_get_basic() {
        let provider = ProviderBuilder::new()
            .connect("https://mainnet.infura.io/v3/c60b0bb42f8a4c6481ecd229eddaca27")
            .await
            .unwrap()
            .erased();
        let db = WrapDatabaseAsync::new(AlloyDB::new(provider, BlockId::from(16148323))).unwrap();

        let address = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852"
            .parse::<Address>()
            .unwrap();

        let info = db.basic_ref(address).unwrap().unwrap();
        assert!(info.exists());
    }
}