1pub use alloy_eips::BlockId;
2use alloy_provider::{
3 network::{primitives::HeaderResponse, BlockResponse},
4 Network, Provider,
5};
6use alloy_transport::TransportError;
7use core::error::Error;
8use database_interface::{async_db::DatabaseAsyncRef, DBErrorMarker};
9use primitives::{Address, B256, U256};
10use state::{AccountInfo, Bytecode};
11use std::fmt::Display;
12
13#[derive(Debug)]
14pub struct DBTransportError(pub TransportError);
15
16impl DBErrorMarker for DBTransportError {}
17
18impl Display for DBTransportError {
19 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
20 write!(f, "Transport error: {}", self.0)
21 }
22}
23
24impl Error for DBTransportError {}
25
26impl From<TransportError> for DBTransportError {
27 fn from(e: TransportError) -> Self {
28 Self(e)
29 }
30}
31
32#[derive(Debug)]
36pub struct AlloyDB<N: Network, P: Provider<N>> {
37 provider: P,
39 block_number: BlockId,
41 _marker: core::marker::PhantomData<fn() -> N>,
42}
43
44impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
45 pub fn new(provider: P, block_number: BlockId) -> Self {
47 Self {
48 provider,
49 block_number,
50 _marker: core::marker::PhantomData,
51 }
52 }
53
54 pub fn set_block_number(&mut self, block_number: BlockId) {
56 self.block_number = block_number;
57 }
58}
59
60impl<N: Network, P: Provider<N>> DatabaseAsyncRef for AlloyDB<N, P> {
61 type Error = DBTransportError;
62
63 async fn basic_async_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
64 let nonce = self
65 .provider
66 .get_transaction_count(address)
67 .block_id(self.block_number);
68 let balance = self
69 .provider
70 .get_balance(address)
71 .block_id(self.block_number);
72 let code = self
73 .provider
74 .get_code_at(address)
75 .block_id(self.block_number);
76
77 let (nonce, balance, code) = tokio::join!(nonce, balance, code,);
78
79 let balance = balance?;
80 let code = Bytecode::new_raw(code?.0.into());
81 let code_hash = code.hash_slow();
82 let nonce = nonce?;
83
84 Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
85 }
86
87 async fn block_hash_async_ref(&self, number: u64) -> Result<B256, Self::Error> {
88 let block = self
89 .provider
90 .get_block_by_number(number.into())
92 .await?;
93 Ok(B256::new(*block.unwrap().header().hash()))
95 }
96
97 async fn code_by_hash_async_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
98 panic!("This should not be called, as the code is already loaded");
99 }
101
102 async fn storage_async_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
103 Ok(self
104 .provider
105 .get_storage_at(address, index)
106 .block_id(self.block_number)
107 .await?)
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_provider::ProviderBuilder;
    use database_interface::{DatabaseRef, WrapDatabaseAsync};

    /// Smoke test against a live mainnet endpoint: loads one account at a
    /// fixed block and checks that it exists. Ignored by default because it
    /// depends on a remote RPC service.
    #[test]
    #[ignore = "flaky RPC"]
    fn can_get_basic() {
        // NOTE(review): the URL appears to embed an Infura project key —
        // consider sourcing it from the environment instead.
        let rpc_url = "https://mainnet.infura.io/v3/c60b0bb42f8a4c6481ecd229eddaca27"
            .parse()
            .unwrap();
        let provider = ProviderBuilder::new().on_http(rpc_url);

        let db = AlloyDB::new(provider, BlockId::from(16148323));
        // Wrap the async database so it can be queried through the blocking
        // `DatabaseRef` interface.
        let sync_db = WrapDatabaseAsync::new(db).unwrap();

        let target: Address = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852"
            .parse()
            .unwrap();

        let info = sync_db.basic_ref(target).unwrap().unwrap();
        assert!(info.exists());
    }
}