1use core::{
3 error::Error,
4 fmt::Display,
5 ops::{Deref, DerefMut},
6};
7use primitives::{Address, StorageKey, StorageValue, B256};
8use state::{
9 bal::{alloy::AlloyBal, Bal, BalError},
10 Account, AccountInfo, Bytecode, EvmState,
11};
12use std::sync::Arc;
13
14use crate::{DBErrorMarker, Database, DatabaseCommit};
15
16#[derive(Clone, Default, Debug, PartialEq, Eq)]
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19pub struct BalState {
20 pub bal: Option<Arc<Bal>>,
22 pub bal_builder: Option<Bal>,
25 pub bal_index: u64,
28}
29
30impl BalState {
31 #[inline]
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 #[inline]
39 pub fn reset_bal_index(&mut self) {
40 self.bal_index = 0;
41 }
42
43 #[inline]
45 pub fn bump_bal_index(&mut self) {
46 self.bal_index += 1;
47 }
48
49 #[inline]
51 pub fn bal_index(&self) -> u64 {
52 self.bal_index
53 }
54
55 #[inline]
57 pub fn bal(&self) -> Option<Arc<Bal>> {
58 self.bal.clone()
59 }
60
61 #[inline]
63 pub fn bal_builder(&self) -> Option<Bal> {
64 self.bal_builder.clone()
65 }
66
67 #[inline]
69 pub fn with_bal(mut self, bal: Arc<Bal>) -> Self {
70 self.bal = Some(bal);
71 self
72 }
73
74 #[inline]
76 pub fn with_bal_builder(mut self) -> Self {
77 self.bal_builder = Some(Bal::new());
78 self
79 }
80
81 #[inline]
83 pub fn take_built_bal(&mut self) -> Option<Bal> {
84 self.reset_bal_index();
85 self.bal_builder.take()
86 }
87
88 #[inline]
90 pub fn take_built_alloy_bal(&mut self) -> Option<AlloyBal> {
91 self.take_built_bal().map(|bal| bal.into_alloy_bal())
92 }
93
94 #[inline]
98 pub fn get_account_id(&self, address: &Address) -> Result<Option<usize>, BalError> {
99 self.bal
100 .as_ref()
101 .map(|bal| {
102 bal.accounts
103 .get_full(address)
104 .map(|i| i.0)
105 .ok_or(BalError::AccountNotFound)
106 })
107 .transpose()
108 }
109
110 #[inline]
116 pub fn basic(
117 &self,
118 address: Address,
119 basic: &mut Option<AccountInfo>,
120 ) -> Result<bool, BalError> {
121 let Some(account_id) = self.get_account_id(&address)? else {
122 return Ok(false);
123 };
124 Ok(self.basic_by_account_id(account_id, basic))
125 }
126
127 #[inline]
131 pub fn basic_by_account_id(&self, account_id: usize, basic: &mut Option<AccountInfo>) -> bool {
132 if let Some(bal) = &self.bal {
133 let is_none = basic.is_none();
134 let mut bal_basic = core::mem::take(basic).unwrap_or_default();
135 bal.populate_account_info(account_id, self.bal_index, &mut bal_basic)
136 .expect("Invalid account id");
137
138 if is_none {
140 return true;
141 }
142
143 *basic = Some(bal_basic);
144 return true;
145 }
146 false
147 }
148
149 #[inline]
153 pub fn storage(
154 &self,
155 account: &Address,
156 storage_key: StorageKey,
157 ) -> Result<Option<StorageValue>, BalError> {
158 let Some(bal) = &self.bal else {
159 return Ok(None);
160 };
161
162 let Some(bal_account) = bal.accounts.get(account) else {
163 return Err(BalError::AccountNotFound);
164 };
165
166 Ok(bal_account
167 .storage
168 .get_bal_writes(storage_key)?
169 .get(self.bal_index))
170 }
171
172 #[inline]
178 pub fn storage_by_account_id(
179 &self,
180 account_id: usize,
181 storage_key: StorageKey,
182 ) -> Result<Option<StorageValue>, BalError> {
183 let Some(bal) = &self.bal else {
184 return Ok(None);
185 };
186
187 let Some((_, bal_account)) = bal.accounts.get_index(account_id) else {
188 return Err(BalError::AccountNotFound);
189 };
190
191 Ok(bal_account
192 .storage
193 .get_bal_writes(storage_key)?
194 .get(self.bal_index))
195 }
196
197 #[inline]
199 pub fn commit(&mut self, changes: &EvmState) {
200 if let Some(bal_builder) = &mut self.bal_builder {
201 for (address, account) in changes.iter() {
202 bal_builder.update_account(self.bal_index, *address, account);
203 }
204 }
205 }
206
207 #[inline]
209 pub fn commit_one(&mut self, address: Address, account: &Account) {
210 if let Some(bal_builder) = &mut self.bal_builder {
211 bal_builder.update_account(self.bal_index, address, account);
212 }
213 }
214}
215
216#[derive(Clone, Debug, PartialEq, Eq)]
218#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
219pub struct BalDatabase<DB> {
220 pub bal_state: BalState,
222 pub db: DB,
224}
225
226impl<DB> Deref for BalDatabase<DB> {
227 type Target = DB;
228
229 fn deref(&self) -> &Self::Target {
230 &self.db
231 }
232}
233
234impl<DB> DerefMut for BalDatabase<DB> {
235 fn deref_mut(&mut self) -> &mut Self::Target {
236 &mut self.db
237 }
238}
239
240impl<DB> BalDatabase<DB> {
241 #[inline]
243 pub fn new(db: DB) -> Self {
244 Self {
245 bal_state: BalState::default(),
246 db,
247 }
248 }
249
250 #[inline]
252 pub fn with_bal_option(self, bal: Option<Arc<Bal>>) -> Self {
253 Self {
254 bal_state: BalState {
255 bal,
256 ..self.bal_state
257 },
258 ..self
259 }
260 }
261
262 #[inline]
264 pub fn with_bal_builder(self) -> Self {
265 Self {
266 bal_state: self.bal_state.with_bal_builder(),
267 ..self
268 }
269 }
270
271 #[inline]
273 pub fn reset_bal_index(mut self) -> Self {
274 self.bal_state.reset_bal_index();
275 self
276 }
277
278 #[inline]
280 pub fn bump_bal_index(&mut self) {
281 self.bal_state.bump_bal_index();
282 }
283}
284
285#[derive(Clone, Debug, PartialEq, Eq)]
287#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
288pub enum EvmDatabaseError<ERROR> {
289 Bal(BalError),
291 Database(ERROR),
293}
294
295impl<ERROR> From<BalError> for EvmDatabaseError<ERROR> {
296 fn from(error: BalError) -> Self {
297 Self::Bal(error)
298 }
299}
300
301impl<ERROR: core::error::Error + Send + Sync + 'static> DBErrorMarker for EvmDatabaseError<ERROR> {}
302
303impl<ERROR: Display> Display for EvmDatabaseError<ERROR> {
304 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
305 match self {
306 Self::Bal(error) => write!(f, "Bal error: {error}"),
307 Self::Database(error) => write!(f, "Database error: {error}"),
308 }
309 }
310}
311
312impl<ERROR: Error> Error for EvmDatabaseError<ERROR> {}
313
314impl<ERROR> EvmDatabaseError<ERROR> {
315 pub fn into_external_error(self) -> ERROR {
319 match self {
320 Self::Bal(_) => panic!("Expected database error, got BAL error"),
321 Self::Database(error) => error,
322 }
323 }
324}
325
326impl<DB: Database> Database for BalDatabase<DB> {
327 type Error = EvmDatabaseError<DB::Error>;
328
329 #[inline]
330 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
331 let account_id = self.bal_state.get_account_id(&address)?;
332
333 let mut account = self.db.basic(address).map_err(EvmDatabaseError::Database)?;
334
335 if let Some(account_id) = account_id {
336 self.bal_state.basic_by_account_id(account_id, &mut account);
337 }
338
339 Ok(account)
340 }
341
342 #[inline]
343 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
344 self.db
345 .code_by_hash(code_hash)
346 .map_err(EvmDatabaseError::Database)
347 }
348
349 #[inline]
350 fn storage(&mut self, address: Address, key: StorageKey) -> Result<StorageValue, Self::Error> {
351 if let Some(storage) = self.bal_state.storage(&address, key)? {
352 return Ok(storage);
353 }
354
355 self.db
356 .storage(address, key)
357 .map_err(EvmDatabaseError::Database)
358 }
359
360 #[inline]
361 fn storage_by_account_id(
362 &mut self,
363 address: Address,
364 account_id: usize,
365 storage_key: StorageKey,
366 ) -> Result<StorageValue, Self::Error> {
367 if let Some(value) = self
368 .bal_state
369 .storage_by_account_id(account_id, storage_key)?
370 {
371 return Ok(value);
372 }
373
374 self.db
375 .storage(address, storage_key)
376 .map_err(EvmDatabaseError::Database)
377 }
378
379 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
380 self.db
381 .block_hash(number)
382 .map_err(EvmDatabaseError::Database)
383 }
384}
385
386impl<DB: DatabaseCommit> DatabaseCommit for BalDatabase<DB> {
387 fn commit(&mut self, changes: EvmState) {
388 self.bal_state.commit(&changes);
389 self.db.commit(changes);
390 }
391}