1use core::{
3 error::Error,
4 fmt::Display,
5 ops::{Deref, DerefMut},
6};
7use primitives::{Address, StorageKey, StorageValue, B256};
8use state::{
9 bal::{alloy::AlloyBal, Bal, BalError, BlockAccessIndex},
10 Account, AccountId, AccountInfo, Bytecode, EvmState,
11};
12use std::sync::Arc;
13
14use crate::{DBErrorMarker, Database, DatabaseCommit};
15
/// State for resolving account/storage reads against a BAL and/or building a new BAL
/// from committed changes.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BalState {
    /// BAL that reads are resolved against; `None` disables BAL reads entirely.
    pub bal: Option<Arc<Bal>>,
    /// BAL being built from committed changes; `None` disables building.
    pub bal_builder: Option<Bal>,
    /// Current block access index, used both for BAL reads and for recording commits.
    pub bal_index: BlockAccessIndex,
    /// When `true`, accounts or storage slots missing from `bal` are reported as
    /// `Ok(None)` so the caller reads them from the database instead; when `false`
    /// such misses are returned as `BalError`s.
    #[cfg_attr(feature = "serde", serde(default))]
    pub allow_db_fallback: bool,
}
38
39impl BalState {
40 #[inline]
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 #[inline]
48 pub const fn reset_bal_index(&mut self) {
49 self.bal_index = BlockAccessIndex::PRE_EXECUTION;
50 }
51
52 #[inline]
54 pub const fn bump_bal_index(&mut self) {
55 self.bal_index.increment();
56 }
57
58 #[inline]
60 pub const fn bal_index(&self) -> BlockAccessIndex {
61 self.bal_index
62 }
63
64 #[inline]
66 pub fn bal(&self) -> Option<Arc<Bal>> {
67 self.bal.clone()
68 }
69
70 #[inline]
72 pub fn bal_builder(&self) -> Option<Bal> {
73 self.bal_builder.clone()
74 }
75
76 #[inline]
78 pub fn with_bal(mut self, bal: Arc<Bal>) -> Self {
79 self.bal = Some(bal);
80 self
81 }
82
83 #[inline]
85 pub fn with_bal_builder(mut self) -> Self {
86 self.bal_builder = Some(Bal::new());
87 self
88 }
89
90 #[inline]
94 pub const fn with_allow_db_fallback(mut self, allow: bool) -> Self {
95 self.allow_db_fallback = allow;
96 self
97 }
98
99 #[inline]
103 pub const fn set_allow_db_fallback(&mut self, allow: bool) {
104 self.allow_db_fallback = allow;
105 }
106
107 #[inline]
109 pub const fn take_built_bal(&mut self) -> Option<Bal> {
110 self.reset_bal_index();
111 self.bal_builder.take()
112 }
113
114 #[inline]
116 pub fn take_built_alloy_bal(&mut self) -> Option<AlloyBal> {
117 self.take_built_bal().map(|bal| bal.into_alloy_bal())
118 }
119
120 #[inline]
128 pub fn get_account_id(&self, address: &Address) -> Result<Option<AccountId>, BalError> {
129 let Some(bal) = self.bal.as_ref() else {
130 return Ok(None);
131 };
132 match bal.accounts.get_full(address) {
133 Some(i) => Ok(Some(AccountId::new(i.0).expect("too many bals"))),
134 None if self.allow_db_fallback => Ok(None),
135 None => Err(BalError::AccountNotFound { address: *address }),
136 }
137 }
138
139 #[inline]
145 pub fn basic(
146 &self,
147 address: Address,
148 basic: &mut Option<AccountInfo>,
149 ) -> Result<bool, BalError> {
150 let Some(account_id) = self.get_account_id(&address)? else {
151 return Ok(false);
152 };
153 self.basic_by_account_id(account_id, basic)
154 }
155
156 #[inline]
158 pub fn basic_by_account_id(
159 &self,
160 account_id: AccountId,
161 basic: &mut Option<AccountInfo>,
162 ) -> Result<bool, BalError> {
163 let Some(bal) = &self.bal else {
164 return Ok(false);
165 };
166 let is_none = basic.is_none();
167 let mut bal_basic = core::mem::take(basic).unwrap_or_default();
168 let changed = bal.populate_account_info(account_id, self.bal_index, &mut bal_basic)?;
169
170 if !changed && is_none {
172 return Ok(true);
173 }
174
175 *basic = Some(bal_basic);
176 Ok(true)
177 }
178
179 #[inline]
187 pub fn storage(
188 &self,
189 account: &Address,
190 storage_key: StorageKey,
191 ) -> Result<Option<StorageValue>, BalError> {
192 let Some(bal) = &self.bal else {
193 return Ok(None);
194 };
195
196 let Some(bal_account) = bal.accounts.get(account) else {
197 if self.allow_db_fallback {
198 return Ok(None);
199 }
200 return Err(BalError::AccountNotFound { address: *account });
201 };
202
203 match bal_account.storage.get_bal_writes(account, storage_key) {
204 Ok(writes) => Ok(writes.get(self.bal_index)),
205 Err(BalError::SlotNotFound { .. }) if self.allow_db_fallback => Ok(None),
206 Err(err) => Err(err),
207 }
208 }
209
210 #[inline]
218 pub fn storage_by_account_id(
219 &self,
220 account_id: AccountId,
221 storage_key: StorageKey,
222 ) -> Result<Option<StorageValue>, BalError> {
223 let Some(bal) = &self.bal else {
224 return Ok(None);
225 };
226
227 let Some((address, bal_account)) = bal.accounts.get_index(account_id.get()) else {
228 return Err(BalError::InvalidAccountId { account_id });
229 };
230
231 match bal_account.storage.get_bal_writes(address, storage_key) {
232 Ok(writes) => Ok(writes.get(self.bal_index)),
233 Err(BalError::SlotNotFound { .. }) if self.allow_db_fallback => Ok(None),
234 Err(err) => Err(err),
235 }
236 }
237
238 #[inline]
240 pub fn commit(&mut self, changes: &EvmState) {
241 if let Some(bal_builder) = &mut self.bal_builder {
242 for (address, account) in changes.iter() {
243 bal_builder.update_account(self.bal_index, *address, account);
244 }
245 }
246 }
247
248 #[inline]
250 pub fn commit_one(&mut self, address: Address, account: &Account) {
251 if let Some(bal_builder) = &mut self.bal_builder {
252 bal_builder.update_account(self.bal_index, address, account);
253 }
254 }
255}
256
/// Database wrapper that serves reads through a [`BalState`] before falling back to
/// the inner `db`, and records commits into the state's BAL builder.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BalDatabase<DB> {
    /// BAL read/build state consulted before `db`.
    pub bal_state: BalState,
    /// Wrapped database: used for reads the BAL does not resolve, and receives all commits.
    pub db: DB,
}
266
267impl<DB> Deref for BalDatabase<DB> {
268 type Target = DB;
269
270 fn deref(&self) -> &Self::Target {
271 &self.db
272 }
273}
274
275impl<DB> DerefMut for BalDatabase<DB> {
276 fn deref_mut(&mut self) -> &mut Self::Target {
277 &mut self.db
278 }
279}
280
281impl<DB> BalDatabase<DB> {
282 #[inline]
284 pub fn new(db: DB) -> Self {
285 Self {
286 bal_state: BalState::default(),
287 db,
288 }
289 }
290
291 #[inline]
293 pub fn with_bal_option(self, bal: Option<Arc<Bal>>) -> Self {
294 Self {
295 bal_state: BalState {
296 bal,
297 ..self.bal_state
298 },
299 ..self
300 }
301 }
302
303 #[inline]
305 pub fn with_bal_builder(self) -> Self {
306 Self {
307 bal_state: self.bal_state.with_bal_builder(),
308 ..self
309 }
310 }
311
312 #[inline]
316 pub const fn with_allow_bal_db_fallback(mut self, allow: bool) -> Self {
317 self.bal_state.allow_db_fallback = allow;
318 self
319 }
320
321 #[inline]
323 pub const fn reset_bal_index(mut self) -> Self {
324 self.bal_state.reset_bal_index();
325 self
326 }
327
328 #[inline]
330 pub const fn bump_bal_index(&mut self) {
331 self.bal_state.bump_bal_index();
332 }
333}
334
/// Error produced by [`BalDatabase`]: either a BAL lookup failure or an error from
/// the wrapped database.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EvmDatabaseError<ERROR> {
    /// BAL lookup failed; `DBErrorMarker::is_fatal` reports this as non-fatal.
    Bal(BalError),
    /// The wrapped database failed; `DBErrorMarker::is_fatal` reports this as fatal.
    Database(ERROR),
}
344
345impl<ERROR> From<BalError> for EvmDatabaseError<ERROR> {
346 fn from(error: BalError) -> Self {
347 Self::Bal(error)
348 }
349}
350
351impl<ERROR: core::error::Error + Send + Sync + 'static> DBErrorMarker for EvmDatabaseError<ERROR> {
352 fn is_fatal(&self) -> bool {
353 match self {
354 Self::Bal(_) => false,
355 Self::Database(_) => true,
356 }
357 }
358}
359
360impl<ERROR: Display> Display for EvmDatabaseError<ERROR> {
361 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
362 match self {
363 Self::Bal(error) => write!(f, "Bal error: {error}"),
364 Self::Database(error) => write!(f, "Database error: {error}"),
365 }
366 }
367}
368
369impl<ERROR: Error> Error for EvmDatabaseError<ERROR> {}
370
371impl<ERROR> EvmDatabaseError<ERROR> {
372 pub fn into_external_error(self) -> ERROR {
376 match self {
377 Self::Bal(_) => panic!("Expected database error, got BAL error"),
378 Self::Database(error) => error,
379 }
380 }
381}
382
383impl<DB: Database> Database for BalDatabase<DB> {
384 type Error = EvmDatabaseError<DB::Error>;
385
386 #[inline]
387 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
388 let account_id = self.bal_state.get_account_id(&address)?;
389
390 let mut account = self.db.basic(address).map_err(EvmDatabaseError::Database)?;
391
392 if let Some(account_id) = account_id {
393 self.bal_state
394 .basic_by_account_id(account_id, &mut account)?;
395 }
396
397 Ok(account)
398 }
399
400 #[inline]
401 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
402 self.db
403 .code_by_hash(code_hash)
404 .map_err(EvmDatabaseError::Database)
405 }
406
407 #[inline]
408 fn storage(&mut self, address: Address, key: StorageKey) -> Result<StorageValue, Self::Error> {
409 if let Some(storage) = self.bal_state.storage(&address, key)? {
410 return Ok(storage);
411 }
412
413 self.db
414 .storage(address, key)
415 .map_err(EvmDatabaseError::Database)
416 }
417
418 #[inline]
419 fn storage_by_account_id(
420 &mut self,
421 address: Address,
422 account_id: AccountId,
423 storage_key: StorageKey,
424 ) -> Result<StorageValue, Self::Error> {
425 if let Some(value) = self
426 .bal_state
427 .storage_by_account_id(account_id, storage_key)?
428 {
429 return Ok(value);
430 }
431
432 self.db
433 .storage(address, storage_key)
434 .map_err(EvmDatabaseError::Database)
435 }
436
437 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
438 self.db
439 .block_hash(number)
440 .map_err(EvmDatabaseError::Database)
441 }
442}
443
444impl<DB: DatabaseCommit> DatabaseCommit for BalDatabase<DB> {
445 fn commit(&mut self, changes: EvmState) {
446 self.bal_state.commit(&changes);
447 self.db.commit(changes);
448 }
449
450 fn commit_iter(&mut self, changes: &mut dyn Iterator<Item = (Address, Account)>) {
451 let bal_state = &mut self.bal_state;
452 let mut changes = changes.map(|(address, account)| {
453 bal_state.commit_one(address, &account);
454 (address, account)
455 });
456 self.db.commit_iter(&mut changes);
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use primitives::U256;
    use state::bal::{AccountBal, BalWrites};

    /// Builds a BAL with one account, `address`, whose slot `slot` has a single
    /// write of `42` at block access index 1.
    fn bal_with_account(address: Address, slot: StorageKey) -> Arc<Bal> {
        let mut account = AccountBal::default();
        account.storage.storage.insert(
            slot,
            BalWrites::new(vec![(BlockAccessIndex::new(1), StorageValue::from(42u64))]),
        );
        Arc::new(Bal::from_iter([(address, account)]))
    }

    /// With fallback left at its default (disabled), missing accounts and slots are errors.
    #[test]
    fn bal_misses_error_without_fallback() {
        let address = Address::with_last_byte(1);
        let missing = Address::with_last_byte(2);
        let slot = U256::from(1);
        let missing_slot = U256::from(2);
        let bal_state = BalState::new().with_bal(bal_with_account(address, slot));

        assert_eq!(
            bal_state.get_account_id(&missing),
            Err(BalError::AccountNotFound { address: missing })
        );
        assert_eq!(
            bal_state.storage(&missing, slot),
            Err(BalError::AccountNotFound { address: missing })
        );
        assert_eq!(
            bal_state.storage(&address, missing_slot),
            Err(BalError::SlotNotFound {
                address,
                slot: missing_slot
            })
        );
    }

    /// With fallback enabled, misses become `Ok(None)` (defer to the database), while
    /// entries present in the BAL are still served from it.
    #[test]
    fn bal_misses_fall_back_to_database_with_fallback() {
        let address = Address::with_last_byte(1);
        let missing = Address::with_last_byte(2);
        let slot = U256::from(1);
        let missing_slot = U256::from(2);
        let mut bal_state = BalState::new()
            .with_bal(bal_with_account(address, slot))
            .with_allow_db_fallback(true);

        // Misses: unknown account, unknown account's slot, known account's unknown slot.
        assert_eq!(bal_state.get_account_id(&missing), Ok(None));
        assert_eq!(bal_state.storage(&missing, slot), Ok(None));
        assert_eq!(bal_state.storage(&address, missing_slot), Ok(None));

        // Hits: reading at index 2 sees the write recorded at index 1.
        bal_state.bal_index = BlockAccessIndex::new(2);
        assert!(bal_state.get_account_id(&address).unwrap().is_some());
        assert_eq!(
            bal_state.storage(&address, slot),
            Ok(Some(StorageValue::from(42u64)))
        );
    }
}