// revm_database/alloydb.rs

1pub use alloy_eips::BlockId;
2use alloy_provider::{
3    network::{
4        primitives::{BlockTransactionsKind, HeaderResponse},
5        BlockResponse,
6    },
7    Network, Provider,
8};
9use alloy_transport::TransportError;
10use core::error::Error;
11use database_interface::{async_db::DatabaseAsyncRef, DBErrorMarker};
12use primitives::{Address, B256, U256};
13use state::{AccountInfo, Bytecode};
14use std::fmt::Display;
15
16#[derive(Debug)]
17pub struct DBTransportError(pub TransportError);
18
19impl DBErrorMarker for DBTransportError {}
20
21impl Display for DBTransportError {
22    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
23        write!(f, "Transport error: {}", self.0)
24    }
25}
26
27impl Error for DBTransportError {}
28
29impl From<TransportError> for DBTransportError {
30    fn from(e: TransportError) -> Self {
31        Self(e)
32    }
33}
34
35/// An alloy-powered REVM [Database][database_interface::Database].
36///
37/// When accessing the database, it'll use the given provider to fetch the corresponding account's data.
38#[derive(Debug)]
39pub struct AlloyDB<N: Network, P: Provider<N>> {
40    /// The provider to fetch the data from.
41    provider: P,
42    /// The block number on which the queries will be based on.
43    block_number: BlockId,
44    _marker: core::marker::PhantomData<fn() -> N>,
45}
46
47impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
48    /// Creates a new AlloyDB instance, with a [Provider] and a block.
49    pub fn new(provider: P, block_number: BlockId) -> Self {
50        Self {
51            provider,
52            block_number,
53            _marker: core::marker::PhantomData,
54        }
55    }
56
57    /// Sets the block number on which the queries will be based on.
58    pub fn set_block_number(&mut self, block_number: BlockId) {
59        self.block_number = block_number;
60    }
61}
62
63impl<N: Network, P: Provider<N>> DatabaseAsyncRef for AlloyDB<N, P> {
64    type Error = DBTransportError;
65
66    async fn basic_async_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
67        let nonce = self
68            .provider
69            .get_transaction_count(address)
70            .block_id(self.block_number);
71        let balance = self
72            .provider
73            .get_balance(address)
74            .block_id(self.block_number);
75        let code = self
76            .provider
77            .get_code_at(address)
78            .block_id(self.block_number);
79
80        let (nonce, balance, code) = tokio::join!(nonce, balance, code,);
81
82        let balance = balance?;
83        let code = Bytecode::new_raw(code?.0.into());
84        let code_hash = code.hash_slow();
85        let nonce = nonce?;
86
87        Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
88    }
89
90    async fn block_hash_async_ref(&self, number: u64) -> Result<B256, Self::Error> {
91        let block = self
92            .provider
93            // SAFETY: We know number <= u64::MAX, so we can safely convert it to u64
94            .get_block_by_number(number.into(), BlockTransactionsKind::Hashes)
95            .await?;
96        // SAFETY: If the number is given, the block is supposed to be finalized, so unwrapping is safe.
97        Ok(B256::new(*block.unwrap().header().hash()))
98    }
99
100    async fn code_by_hash_async_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
101        panic!("This should not be called, as the code is already loaded");
102        // This is not needed, as the code is already loaded with basic_ref
103    }
104
105    async fn storage_async_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
106        Ok(self
107            .provider
108            .get_storage_at(address, index)
109            .block_id(self.block_number)
110            .await?)
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_provider::ProviderBuilder;
    use database_interface::{DatabaseRef, WrapDatabaseAsync};

    /// Environment variable that overrides the RPC endpoint used by network tests.
    const RPC_URL_ENV: &str = "ETH_RPC_URL";

    /// Endpoint used when `ETH_RPC_URL` is not set.
    ///
    /// NOTE(review): this hard-codes an Infura API key in source control. Prefer
    /// exporting `ETH_RPC_URL` with your own endpoint, and consider dropping this fallback.
    const DEFAULT_RPC_URL: &str = "https://mainnet.infura.io/v3/c60b0bb42f8a4c6481ecd229eddaca27";

    /// Returns the RPC endpoint for network tests: the env override if set, else the default.
    fn rpc_url() -> String {
        std::env::var(RPC_URL_ENV).unwrap_or_else(|_| DEFAULT_RPC_URL.to_owned())
    }

    /// Fetches a known contract account at a fixed historical block and checks it exists.
    #[test]
    #[ignore = "flaky RPC"]
    fn can_get_basic() {
        let client = ProviderBuilder::new().on_http(rpc_url().parse().unwrap());
        let alloydb = AlloyDB::new(client, BlockId::from(16148323));
        let wrapped_alloydb = WrapDatabaseAsync::new(alloydb).unwrap();

        // ETH/USDT pair on Uniswap V2
        let address: Address = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852"
            .parse()
            .unwrap();

        let acc_info = wrapped_alloydb.basic_ref(address).unwrap().unwrap();
        assert!(acc_info.exists());
    }
}
139}