revme/cmd/statetest/
runner.rs

use crate::cmd::statetest::merkle_trie::{compute_test_roots, TestValidationResult};
use indicatif::{ProgressBar, ProgressDrawTarget};
use revm::{
    context::{block::BlockEnv, cfg::CfgEnv, tx::TxEnv},
    context_interface::{
        result::{EVMError, ExecutionResult, HaltReason, InvalidTransaction},
        Cfg,
    },
    database,
    database_interface::EmptyDB,
    inspector::{inspectors::TracerEip3155, InspectCommitEvm},
    primitives::{hardfork::SpecId, Bytes, B256, U256},
    Context, ExecuteCommitEvm, MainBuilder, MainContext,
};
use serde_json::json;
use statetest_types::{SpecName, Test, TestSuite, TestUnit};
use std::{
    convert::Infallible,
    fmt::Debug,
    io::stderr,
    path::{Path, PathBuf},
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc, Mutex,
    },
    time::{Duration, Instant},
};
use thiserror::Error;
use walkdir::{DirEntry, WalkDir};

/// Error that occurs during test execution
#[derive(Debug, Error)]
#[error("Path: {path}\nName: {name}\nError: {kind}")]
pub struct TestError {
    pub name: String,
    pub path: String,
    pub kind: TestErrorKind,
}

/// Specific kind of error that occurred during test execution
#[derive(Debug, Error)]
pub enum TestErrorKind {
    #[error("logs root mismatch: got {got}, expected {expected}")]
    LogsRootMismatch { got: B256, expected: B256 },
    #[error("state root mismatch: got {got}, expected {expected}")]
    StateRootMismatch { got: B256, expected: B256 },
    #[error("unknown private key: {0:?}")]
    UnknownPrivateKey(B256),
    #[error("unexpected exception: got {got_exception:?}, expected {expected_exception:?}")]
    UnexpectedException {
        expected_exception: Option<String>,
        got_exception: Option<String>,
    },
    #[error("unexpected output: got {got_output:?}, expected {expected_output:?}")]
    UnexpectedOutput {
        expected_output: Option<Bytes>,
        got_output: Option<Bytes>,
    },
    #[error(transparent)]
    SerdeDeserialize(#[from] serde_json::Error),
    #[error("thread panicked")]
    Panic,
    #[error("path does not exist")]
    InvalidPath,
    #[error("no JSON test files found in path")]
    NoJsonFiles,
}

/// Finds all JSON test files in the given path.
///
/// If the path is a file, it is returned in a single-element vector.
/// If the path is a directory, all `.json` files in it are found recursively.
pub fn find_all_json_tests(path: &Path) -> Vec<PathBuf> {
    if path.is_file() {
        vec![path.to_path_buf()]
    } else {
        WalkDir::new(path)
            .into_iter()
            .filter_map(Result::ok)
            .filter(|e| e.path().extension() == Some("json".as_ref()))
            .map(DirEntry::into_path)
            .collect()
    }
}

/// Checks whether a test should be skipped based on its path or filename.
/// Some tests are known to be problematic or take too long to run.
fn skip_test(path: &Path) -> bool {
    let path_str = path.to_str().unwrap_or_default();

    // Skip tests that have storage for a newly created account.
    if path_str.contains("paris/eip7610_create_collision") {
        return true;
    }

    let name = path.file_name().unwrap().to_str().unwrap_or_default();

    matches!(
        name,
        // Test checks whether the gas price overflows; we handle this correctly, but the result does not match the test's specific exception.
        | "CreateTransactionHighNonce.json"

        // Tests with storage checks.
        | "RevertInCreateInInit_Paris.json"
        | "RevertInCreateInInit.json"
        | "dynamicAccountOverwriteEmpty.json"
        | "dynamicAccountOverwriteEmpty_Paris.json"
        | "RevertInCreateInInitCreate2Paris.json"
        | "create2collisionStorage.json"
        | "RevertInCreateInInitCreate2.json"
        | "create2collisionStorageParis.json"
        | "InitCollision.json"
        | "InitCollisionParis.json"
        | "test_init_collision_create_opcode.json"

        // Malformed value.
        | "ValueOverflow.json"
        | "ValueOverflowParis.json"

        // These tests pass, but they take a long time to execute, so we skip them.
        | "Call50000_sha256.json"
        | "static_Call50000_sha256.json"
        | "loopMul.json"
        | "CALLBlake2f_MaxRounds.json"
    )
}

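/// Borrowed inputs needed to execute a single test case from a test unit.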
struct TestExecutionContext<'a> {
    name: &'a str,
    unit: &'a TestUnit,
    test: &'a Test,
    cfg: &'a CfgEnv,
    block: &'a BlockEnv,
    tx: &'a TxEnv,
    cache_state: &'a database::CacheState,
    elapsed: &'a Arc<Mutex<Duration>>,
    trace: bool,
    print_json_outcome: bool,
}

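/// Borrowed inputs needed to re-run a failed test case with tracing enabled.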
struct DebugContext<'a> {
    name: &'a str,
    path: &'a str,
    index: usize,
    test: &'a Test,
    cfg: &'a CfgEnv,
    block: &'a BlockEnv,
    tx: &'a TxEnv,
    cache_state: &'a database::CacheState,
    error: &'a TestErrorKind,
}

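/// Builds the JSON outcome object for a single test case from the execution
/// result, the computed state/logs roots, and an optional error message.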
fn build_json_output(
    test: &Test,
    test_name: &str,
    exec_result: &Result<ExecutionResult<HaltReason>, EVMError<Infallible, InvalidTransaction>>,
    validation: &TestValidationResult,
    spec: SpecId,
    error: Option<String>,
) -> serde_json::Value {
    json!({
        "stateRoot": validation.state_root,
        "logsRoot": validation.logs_root,
        "output": exec_result.as_ref().ok().and_then(|r| r.output().cloned()).unwrap_or_default(),
        "gasUsed": exec_result.as_ref().ok().map(|r| r.gas_used()).unwrap_or_default(),
        "pass": error.is_none(),
        "errorMsg": error.unwrap_or_default(),
        "evmResult": format_evm_result(exec_result),
        "postLogsHash": validation.logs_root,
        "fork": spec,
        "test": test_name,
        "d": test.indexes.data,
        "g": test.indexes.gas,
        "v": test.indexes.value,
    })
}

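/// Formats an EVM execution result (or error) as a short human-readable string.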
fn format_evm_result(
    exec_result: &Result<ExecutionResult<HaltReason>, EVMError<Infallible, InvalidTransaction>>,
) -> String {
    match exec_result {
        Ok(r) => match r {
            ExecutionResult::Success { reason, .. } => format!("Success: {reason:?}"),
            ExecutionResult::Revert { .. } => "Revert".to_string(),
            ExecutionResult::Halt { reason, .. } => format!("Halt: {reason:?}"),
        },
        Err(e) => e.to_string(),
    }
}

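/// Checks whether the outcome matches the test's expected exception.
/// Returns `Ok(true)` if an expected exception occurred, `Ok(false)` if no
/// exception was expected and execution succeeded, and an error otherwise.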
fn validate_exception(
    test: &Test,
    exec_result: &Result<ExecutionResult<HaltReason>, EVMError<Infallible, InvalidTransaction>>,
) -> Result<bool, TestErrorKind> {
    match (&test.expect_exception, exec_result) {
        (None, Ok(_)) => Ok(false), // No exception expected, execution succeeded
        (Some(_), Err(_)) => Ok(true), // Exception expected and occurred
        _ => Err(TestErrorKind::UnexpectedException {
            expected_exception: test.expect_exception.clone(),
            got_exception: exec_result.as_ref().err().map(|e| e.to_string()),
        }),
    }
}

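/// Compares the expected output bytes (if any) against the actual execution output.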
fn validate_output(
    expected_output: Option<&Bytes>,
    actual_result: &ExecutionResult<HaltReason>,
) -> Result<(), TestErrorKind> {
    if let Some((expected, actual)) = expected_output.zip(actual_result.output()) {
        if expected != actual {
            return Err(TestErrorKind::UnexpectedOutput {
                expected_output: Some(expected.clone()),
                got_output: actual_result.output().cloned(),
            });
        }
    }
    Ok(())
}

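/// Validates a single test execution: expected exception, output bytes,
/// logs root, and state root. Optionally prints the outcome as a JSON line.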
fn check_evm_execution(
    test: &Test,
    expected_output: Option<&Bytes>,
    test_name: &str,
    exec_result: &Result<ExecutionResult<HaltReason>, EVMError<Infallible, InvalidTransaction>>,
    db: &mut database::State<EmptyDB>,
    spec: SpecId,
    print_json_outcome: bool,
) -> Result<(), TestErrorKind> {
    let validation = compute_test_roots(exec_result, db);

    let print_json = |error: Option<&TestErrorKind>| {
        if print_json_outcome {
            let json = build_json_output(
                test,
                test_name,
                exec_result,
                &validation,
                spec,
                error.map(|e| e.to_string()),
            );
            eprintln!("{json}");
        }
    };

    // Check if exception handling is correct
    let exception_expected = validate_exception(test, exec_result).inspect_err(|e| {
        print_json(Some(e));
    })?;

    // If exception was expected and occurred, we're done
    if exception_expected {
        print_json(None);
        return Ok(());
    }

    // Validate output if execution succeeded
    if let Ok(result) = exec_result {
        validate_output(expected_output, result).inspect_err(|e| {
            print_json(Some(e));
        })?;
    }

    // Validate logs root
    if validation.logs_root != test.logs {
        let error = TestErrorKind::LogsRootMismatch {
            got: validation.logs_root,
            expected: test.logs,
        };
        print_json(Some(&error));
        return Err(error);
    }

    // Validate state root
    if validation.state_root != test.hash {
        let error = TestErrorKind::StateRootMismatch {
            got: validation.state_root,
            expected: test.hash,
        };
        print_json(Some(&error));
        return Err(error);
    }

    print_json(None);
    Ok(())
}

/// Execute a single test suite file containing multiple tests
///
/// # Arguments
/// * `path` - Path to the JSON test file
/// * `elapsed` - Shared counter for total execution time
/// * `trace` - Whether to enable EVM tracing
/// * `print_json_outcome` - Whether to print JSON formatted results
pub fn execute_test_suite(
    path: &Path,
    elapsed: &Arc<Mutex<Duration>>,
    trace: bool,
    print_json_outcome: bool,
) -> Result<(), TestError> {
    if skip_test(path) {
        return Ok(());
    }

    let s = std::fs::read_to_string(path).unwrap();
    let path = path.to_string_lossy().into_owned();
    let suite: TestSuite = serde_json::from_str(&s).map_err(|e| TestError {
        name: "Unknown".to_string(),
        path: path.clone(),
        kind: e.into(),
    })?;

    for (name, unit) in suite.0 {
        // Prepare initial state
        let cache_state = unit.state();

        // Set up base configuration
        let mut cfg = CfgEnv::default();
        cfg.chain_id = unit
            .env
            .current_chain_id
            .unwrap_or(U256::ONE)
            .try_into()
            .unwrap_or(1);

        // Execute the tests for each post-state spec
        for (spec_name, tests) in &unit.post {
            // Skip Constantinople spec
            if *spec_name == SpecName::Constantinople {
                continue;
            }

            cfg.spec = spec_name.to_spec_id();

            // Configure max blobs per spec
            if cfg.spec.is_enabled_in(SpecId::OSAKA) {
                cfg.set_max_blobs_per_tx(6);
            } else if cfg.spec.is_enabled_in(SpecId::PRAGUE) {
                cfg.set_max_blobs_per_tx(9);
            } else {
                cfg.set_max_blobs_per_tx(6);
            }

            // Set up block environment for this spec
            let block = unit.block_env(&mut cfg);

            for (index, test) in tests.iter().enumerate() {
                // Set up transaction environment
                let tx = match test.tx_env(&unit) {
                    Ok(tx) => tx,
                    Err(_) if test.expect_exception.is_some() => continue,
                    Err(_) => {
                        return Err(TestError {
                            name: name.clone(),
                            path: path.clone(),
                            kind: TestErrorKind::UnknownPrivateKey(unit.transaction.secret_key),
                        });
                    }
                };

                // Execute the test
                let result = execute_single_test(TestExecutionContext {
                    name: &name,
                    unit: &unit,
                    test,
                    cfg: &cfg,
                    block: &block,
                    tx: &tx,
                    cache_state: &cache_state,
                    elapsed,
                    trace,
                    print_json_outcome,
                });

                if let Err(e) = result {
                    // Handle error with debug trace if needed
                    static FAILED: AtomicBool = AtomicBool::new(false);
                    if print_json_outcome || FAILED.swap(true, Ordering::SeqCst) {
                        return Err(TestError {
                            name: name.clone(),
                            path: path.clone(),
                            kind: e,
                        });
                    }

                    // Re-run with trace for debugging
                    debug_failed_test(DebugContext {
                        name: &name,
                        path: &path,
                        index,
                        test,
                        cfg: &cfg,
                        block: &block,
                        tx: &tx,
                        cache_state: &cache_state,
                        error: &e,
                    });

                    return Err(TestError {
                        path: path.clone(),
                        name: name.clone(),
                        kind: e,
                    });
                }
            }
        }
    }
    Ok(())
}

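/// Builds the EVM from the prepared context, executes the transaction
/// (with an EIP-3155 tracer when tracing is enabled), accumulates the elapsed
/// time, and validates the execution result.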
fn execute_single_test(ctx: TestExecutionContext) -> Result<(), TestErrorKind> {
    // Prepare state
    let mut cache = ctx.cache_state.clone();
    cache.set_state_clear_flag(ctx.cfg.spec.is_enabled_in(SpecId::SPURIOUS_DRAGON));
    let mut state = database::State::builder()
        .with_cached_prestate(cache)
        .with_bundle_update()
        .build();

    let evm_context = Context::mainnet()
        .with_block(ctx.block)
        .with_tx(ctx.tx)
        .with_cfg(ctx.cfg)
        .with_db(&mut state);

    // Execute
    let timer = Instant::now();
    let (db, exec_result) = if ctx.trace {
        let mut evm = evm_context
            .build_mainnet_with_inspector(TracerEip3155::buffered(stderr()).without_summary());
        let res = evm.inspect_tx_commit(ctx.tx);
        let db = evm.ctx.journaled_state.database;
        (db, res)
    } else {
        let mut evm = evm_context.build_mainnet();
        let res = evm.transact_commit(ctx.tx);
        let db = evm.ctx.journaled_state.database;
        (db, res)
    };
    *ctx.elapsed.lock().unwrap() += timer.elapsed();

    // Check results
    check_evm_execution(
        ctx.test,
        ctx.unit.out.as_ref(),
        ctx.name,
        &exec_result,
        db,
        ctx.cfg.spec(),
        ctx.print_json_outcome,
    )
}

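/// Re-runs a failed test with the EIP-3155 tracer enabled and prints the
/// execution result, expected exception, pre/post state, and environment.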
fn debug_failed_test(ctx: DebugContext) {
    println!("\nTraces:");

    // Re-run with tracing
    let mut cache = ctx.cache_state.clone();
    cache.set_state_clear_flag(ctx.cfg.spec.is_enabled_in(SpecId::SPURIOUS_DRAGON));
    let mut state = database::State::builder()
        .with_cached_prestate(cache)
        .with_bundle_update()
        .build();

    let mut evm = Context::mainnet()
        .with_db(&mut state)
        .with_block(ctx.block)
        .with_tx(ctx.tx)
        .with_cfg(ctx.cfg)
        .build_mainnet_with_inspector(TracerEip3155::buffered(stderr()).without_summary());

    let exec_result = evm.inspect_tx_commit(ctx.tx);

    println!("\nExecution result: {exec_result:#?}");
    println!("\nExpected exception: {:?}", ctx.test.expect_exception);
    println!("\nState before:\n{}", ctx.cache_state.pretty_print());
    println!(
        "\nState after:\n{}",
        evm.ctx.journaled_state.database.cache.pretty_print()
    );
    println!("\nSpecification: {:?}", ctx.cfg.spec);
    println!("\nTx: {:#?}", ctx.tx);
    println!("Block: {:#?}", ctx.block);
    println!("Cfg: {:#?}", ctx.cfg);
    println!(
        "\nTest name: {:?} (index: {}, path: {:?}) failed:\n{}",
        ctx.name, ctx.index, ctx.path, ctx.error
    );
}

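/// Runtime options for the test runner. Constructed via [`TestRunnerConfig::new`],
/// which makes `trace` imply `print_outcome` and `print_outcome` imply `single_thread`.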
#[derive(Clone, Copy)]
struct TestRunnerConfig {
    single_thread: bool,
    trace: bool,
    print_outcome: bool,
    keep_going: bool,
}

impl TestRunnerConfig {
    fn new(single_thread: bool, trace: bool, print_outcome: bool, keep_going: bool) -> Self {
        // Trace implies print_outcome
        let print_outcome = print_outcome || trace;
        // print_outcome or trace implies single_thread
        let single_thread = single_thread || print_outcome;

        Self {
            single_thread,
            trace,
            print_outcome,
            keep_going,
        }
    }
}

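/// State shared between worker threads: error counter, progress bar,
/// work queue, and accumulated execution time.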
#[derive(Clone)]
struct TestRunnerState {
    n_errors: Arc<AtomicUsize>,
    console_bar: Arc<ProgressBar>,
    queue: Arc<Mutex<(usize, Vec<PathBuf>)>>,
    elapsed: Arc<Mutex<Duration>>,
}

impl TestRunnerState {
    fn new(test_files: Vec<PathBuf>) -> Self {
        let n_files = test_files.len();
        Self {
            n_errors: Arc::new(AtomicUsize::new(0)),
            console_bar: Arc::new(ProgressBar::with_draw_target(
                Some(n_files as u64),
                ProgressDrawTarget::stdout(),
            )),
            queue: Arc::new(Mutex::new((0usize, test_files))),
            elapsed: Arc::new(Mutex::new(Duration::ZERO)),
        }
    }

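    /// Returns the next test path from the shared queue and advances the
    /// index, or `None` once all tests have been handed out.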
    fn next_test(&self) -> Option<PathBuf> {
        let (current_idx, queue) = &mut *self.queue.lock().unwrap();
        let idx = *current_idx;
        let test_path = queue.get(idx).cloned()?;
        *current_idx = idx + 1;
        Some(test_path)
    }
}

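/// Worker loop: pulls test files from the shared queue and executes them until
/// the queue is exhausted or an error is encountered (unless `keep_going` is set).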
fn run_test_worker(state: TestRunnerState, config: TestRunnerConfig) -> Result<(), TestError> {
    loop {
        if !config.keep_going && state.n_errors.load(Ordering::SeqCst) > 0 {
            return Ok(());
        }

        let Some(test_path) = state.next_test() else {
            return Ok(());
        };

        let result = execute_test_suite(
            &test_path,
            &state.elapsed,
            config.trace,
            config.print_outcome,
        );

        state.console_bar.inc(1);

        if let Err(err) = result {
            state.n_errors.fetch_add(1, Ordering::SeqCst);
            if !config.keep_going {
                return Err(err);
            }
        }
    }
}

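/// Determines how many worker threads to spawn: one when single-threaded mode
/// is forced or parallelism cannot be detected, otherwise the available
/// parallelism capped at the number of test files.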
fn determine_thread_count(single_thread: bool, n_files: usize) -> usize {
    match (single_thread, std::thread::available_parallelism()) {
        (true, _) | (false, Err(_)) => 1,
        (false, Ok(n)) => n.get().min(n_files),
    }
}

/// Run all test files in parallel or single-threaded mode
///
/// # Arguments
/// * `test_files` - List of test files to execute
/// * `single_thread` - Force single-threaded execution
/// * `trace` - Enable EVM execution tracing
/// * `print_outcome` - Print test outcomes in JSON format
/// * `keep_going` - Continue running tests even if some fail
pub fn run(
    test_files: Vec<PathBuf>,
    single_thread: bool,
    trace: bool,
    print_outcome: bool,
    keep_going: bool,
) -> Result<(), TestError> {
    let config = TestRunnerConfig::new(single_thread, trace, print_outcome, keep_going);
    let n_files = test_files.len();
    let state = TestRunnerState::new(test_files);
    let num_threads = determine_thread_count(config.single_thread, n_files);

    // Spawn worker threads
    let mut handles = Vec::with_capacity(num_threads);
    for i in 0..num_threads {
        let state = state.clone();

        let thread = std::thread::Builder::new()
            .name(format!("runner-{i}"))
            .spawn(move || run_test_worker(state, config))
            .unwrap();

        handles.push(thread);
    }

    // Collect results from all threads
    let mut thread_errors = Vec::new();
    for (i, handle) in handles.into_iter().enumerate() {
        match handle.join() {
            Ok(Ok(())) => {}
            Ok(Err(e)) => thread_errors.push(e),
            Err(_) => thread_errors.push(TestError {
                name: format!("thread {i} panicked"),
                path: String::new(),
                kind: TestErrorKind::Panic,
            }),
        }
    }

    state.console_bar.finish();

    // Print summary
    println!(
        "Finished execution. Total CPU time: {:.6}s",
        state.elapsed.lock().unwrap().as_secs_f64()
    );

    let n_errors = state.n_errors.load(Ordering::SeqCst);
    let n_thread_errors = thread_errors.len();

    if n_errors == 0 && n_thread_errors == 0 {
        println!("All tests passed!");
        Ok(())
    } else {
        println!("Encountered {n_errors} errors out of {n_files} total tests");

        if n_thread_errors == 0 {
            std::process::exit(1);
        }

        if n_thread_errors > 1 {
            println!("{n_thread_errors} threads returned an error, out of {num_threads} total:");
            for error in &thread_errors {
                println!("{error}");
            }
        }
        Err(thread_errors.swap_remove(0))
    }
}