1pub use alloy_eips::BlockId;
4use alloy_provider::{
5 network::{primitives::HeaderResponse, BlockResponse},
6 Network, Provider,
7};
8use alloy_transport::TransportError;
9use core::error::Error;
10use database_interface::{async_db::DatabaseAsyncRef, DBErrorMarker};
11use primitives::{Address, StorageKey, StorageValue, B256};
12use state::{AccountInfo, Bytecode};
13use std::fmt::Display;
14
15#[derive(Debug)]
17pub struct DBTransportError(pub TransportError);
18
19impl DBErrorMarker for DBTransportError {}
20
21impl Display for DBTransportError {
22 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
23 write!(f, "Transport error: {}", self.0)
24 }
25}
26
27impl Error for DBTransportError {}
28
29impl From<TransportError> for DBTransportError {
30 fn from(e: TransportError) -> Self {
31 Self(e)
32 }
33}
34
35#[derive(Debug)]
39pub struct AlloyDB<N: Network, P: Provider<N>> {
40 provider: P,
42 block_number: BlockId,
44 _marker: core::marker::PhantomData<fn() -> N>,
45}
46
47impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
48 pub fn new(provider: P, block_number: BlockId) -> Self {
50 Self {
51 provider,
52 block_number,
53 _marker: core::marker::PhantomData,
54 }
55 }
56
57 pub fn set_block_number(&mut self, block_number: BlockId) {
59 self.block_number = block_number;
60 }
61}
62
63impl<N: Network, P: Provider<N>> DatabaseAsyncRef for AlloyDB<N, P> {
64 type Error = DBTransportError;
65
66 async fn basic_async_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
67 let nonce = self
68 .provider
69 .get_transaction_count(address)
70 .block_id(self.block_number);
71 let balance = self
72 .provider
73 .get_balance(address)
74 .block_id(self.block_number);
75 let code = self
76 .provider
77 .get_code_at(address)
78 .block_id(self.block_number);
79
80 let (nonce, balance, code) = tokio::join!(nonce, balance, code,);
81
82 let balance = balance?;
83 let code = Bytecode::new_raw(code?.0.into());
84 let code_hash = code.hash_slow();
85 let nonce = nonce?;
86
87 Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
88 }
89
90 async fn block_hash_async_ref(&self, number: u64) -> Result<B256, Self::Error> {
91 let block = self
92 .provider
93 .get_block_by_number(number.into())
95 .await?;
96 Ok(B256::new(*block.unwrap().header().hash()))
98 }
99
100 async fn code_by_hash_async_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
101 panic!("This should not be called, as the code is already loaded");
102 }
104
105 async fn storage_async_ref(
106 &self,
107 address: Address,
108 index: StorageKey,
109 ) -> Result<StorageValue, Self::Error> {
110 Ok(self
111 .provider
112 .get_storage_at(address, index)
113 .block_id(self.block_number)
114 .await?)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_provider::ProviderBuilder;
    use database_interface::{DatabaseRef, WrapDatabaseAsync};

    /// Fetches a known mainnet account through the synchronous wrapper and
    /// checks that it is reported as existing.
    #[tokio::test]
    #[ignore = "flaky RPC"]
    async fn can_get_basic() {
        // NOTE(review): hard-coded Infura project id; consider reading the RPC
        // URL from an environment variable instead.
        let rpc_url = "https://mainnet.infura.io/v3/c60b0bb42f8a4c6481ecd229eddaca27";
        let provider = ProviderBuilder::new().connect(rpc_url).await.unwrap().erased();

        let db = AlloyDB::new(provider, BlockId::from(16148323));
        let sync_db = WrapDatabaseAsync::new(db).unwrap();

        let target = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852"
            .parse::<Address>()
            .unwrap();

        let info = sync_db.basic_ref(target).unwrap().unwrap();
        assert!(info.exists());
    }
}
143}