1pub use alloy_eips::BlockId;
2use alloy_provider::{
3 network::{
4 primitives::{BlockTransactionsKind, HeaderResponse},
5 BlockResponse,
6 },
7 Network, Provider,
8};
9use alloy_transport::TransportError;
10use core::error::Error;
11use database_interface::{async_db::DatabaseAsyncRef, DBErrorMarker};
12use primitives::{Address, B256, U256};
13use state::{AccountInfo, Bytecode};
14use std::fmt::Display;
15
/// Error returned by [`AlloyDB`] queries: a newtype around alloy's [`TransportError`].
///
/// The inner error is public, so callers can inspect it directly via `.0`.
#[derive(Debug)]
pub struct DBTransportError(pub TransportError);
18
// Empty marker impl. `DBTransportError` is used as `DatabaseAsyncRef::Error` further down;
// presumably that associated type is what requires `DBErrorMarker` — confirm against
// `database_interface`.
impl DBErrorMarker for DBTransportError {}
20
21impl Display for DBTransportError {
22 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
23 write!(f, "Transport error: {}", self.0)
24 }
25}
26
// Uses only the trait's default methods: the wrapped transport error is surfaced through the
// `Display` impl rather than through `source()`.
impl Error for DBTransportError {}
28
29impl From<TransportError> for DBTransportError {
30 fn from(e: TransportError) -> Self {
31 Self(e)
32 }
33}
34
/// Database that answers state queries by asking an alloy [`Provider`] on demand.
///
/// Account and storage queries are evaluated at the block identified by `block_number`.
#[derive(Debug)]
pub struct AlloyDB<N: Network, P: Provider<N>> {
    /// Provider used to issue the state queries.
    provider: P,
    /// Block at which account and storage queries are evaluated. Despite the field name this
    /// is a [`BlockId`], not a bare block number.
    block_number: BlockId,
    /// Zero-sized marker that ties the struct to the network type `N` without storing a
    /// value of it.
    _marker: core::marker::PhantomData<fn() -> N>,
}
46
47impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
48 pub fn new(provider: P, block_number: BlockId) -> Self {
50 Self {
51 provider,
52 block_number,
53 _marker: core::marker::PhantomData,
54 }
55 }
56
57 pub fn set_block_number(&mut self, block_number: BlockId) {
59 self.block_number = block_number;
60 }
61}
62
63impl<N: Network, P: Provider<N>> DatabaseAsyncRef for AlloyDB<N, P> {
64 type Error = DBTransportError;
65
66 async fn basic_async_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
67 let nonce = self
68 .provider
69 .get_transaction_count(address)
70 .block_id(self.block_number);
71 let balance = self
72 .provider
73 .get_balance(address)
74 .block_id(self.block_number);
75 let code = self
76 .provider
77 .get_code_at(address)
78 .block_id(self.block_number);
79
80 let (nonce, balance, code) = tokio::join!(nonce, balance, code,);
81
82 let balance = balance?;
83 let code = Bytecode::new_raw(code?.0.into());
84 let code_hash = code.hash_slow();
85 let nonce = nonce?;
86
87 Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
88 }
89
90 async fn block_hash_async_ref(&self, number: u64) -> Result<B256, Self::Error> {
91 let block = self
92 .provider
93 .get_block_by_number(number.into(), BlockTransactionsKind::Hashes)
95 .await?;
96 Ok(B256::new(*block.unwrap().header().hash()))
98 }
99
100 async fn code_by_hash_async_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
101 panic!("This should not be called, as the code is already loaded");
102 }
104
105 async fn storage_async_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
106 Ok(self
107 .provider
108 .get_storage_at(address, index)
109 .block_id(self.block_number)
110 .await?)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_provider::ProviderBuilder;
    use database_interface::{DatabaseRef, WrapDatabaseAsync};

    /// Queries a live mainnet endpoint for a fixed address at a fixed block and checks that
    /// the returned account info reports the account as existing. Ignored by default because
    /// it depends on a remote RPC endpoint.
    #[test]
    #[ignore = "flaky RPC"]
    fn can_get_basic() {
        let rpc_url = "https://mainnet.infura.io/v3/c60b0bb42f8a4c6481ecd229eddaca27"
            .parse()
            .unwrap();
        let provider = ProviderBuilder::new().on_http(rpc_url);
        let db = WrapDatabaseAsync::new(AlloyDB::new(provider, BlockId::from(16148323))).unwrap();

        let target: Address = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852"
            .parse()
            .unwrap();

        let info = db.basic_ref(target).unwrap().unwrap();
        assert!(info.exists());
    }
}