// Path: revm_bytecode/legacy/jump_map.rs

1use bitvec::vec::BitVec;
2use core::{
3    cmp::Ordering,
4    hash::{Hash, Hasher},
5};
6use primitives::{hex, Bytes, OnceLock};
7use std::{fmt::Debug, sync::Arc};
8
/// A table of valid `jump` destinations.
///
/// It is immutable, cheap to clone and memory efficient, with one bit per byte in the bytecode.
#[derive(Clone, Eq)]
pub struct JumpTable {
    /// Cached raw pointer to the first byte of `table`, letting [`Self::is_valid`]
    /// skip the `Arc` indirection on every lookup.
    ///
    /// Invariant: always points into the heap buffer owned by `table` below,
    /// which stays alive (and immutable) as long as this value or any clone exists.
    table_ptr: *const u8,
    /// Number of bits in the table (one bit per bytecode byte).
    len: usize,
    /// Shared backing storage for the bit table; `table_ptr` points into this.
    table: Arc<Bytes>,
}
21
// SAFETY: `table_ptr` points into the buffer owned by the `Arc<Bytes>` field of
// the same struct. That data is never mutated, and the allocation is kept alive
// by the `Arc` for as long as this value (or any clone) exists, so the cached
// pointer remains valid when the struct is moved or shared across threads.
unsafe impl Send for JumpTable {}
unsafe impl Sync for JumpTable {}
25
26impl PartialEq for JumpTable {
27    fn eq(&self, other: &Self) -> bool {
28        self.table.eq(&other.table)
29    }
30}
31
32impl Hash for JumpTable {
33    fn hash<H: Hasher>(&self, state: &mut H) {
34        self.table.hash(state);
35    }
36}
37
38impl PartialOrd for JumpTable {
39    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for JumpTable {
45    fn cmp(&self, other: &Self) -> Ordering {
46        self.table.cmp(&other.table)
47    }
48}
49
50impl Debug for JumpTable {
51    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
52        f.debug_struct("JumpTable")
53            .field("map", &hex::encode(self.table.as_ref()))
54            .finish()
55    }
56}
57
58impl Default for JumpTable {
59    #[inline]
60    fn default() -> Self {
61        static DEFAULT: OnceLock<JumpTable> = OnceLock::new();
62        DEFAULT.get_or_init(|| Self::new(BitVec::default())).clone()
63    }
64}
65
#[cfg(feature = "serde")]
impl serde::Serialize for JumpTable {
    /// Serializes as a `BitVec<u8>` whose bit length is the logical `len`,
    /// matching the format [`serde::Deserialize`] expects.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes = self.table.to_vec();
        let mut bits = BitVec::<u8>::from_vec(bytes);
        // `from_vec` yields 8 bits per byte; trim to the number of used bits.
        bits.resize(self.len, false);
        bits.serialize(serializer)
    }
}
77
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for JumpTable {
    /// Deserializes a `BitVec<u8>` and wraps it via [`Self::new`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        BitVec::<u8>::deserialize(deserializer).map(Self::new)
    }
}
88
impl JumpTable {
    /// Create new JumpTable directly from an existing BitVec.
    ///
    /// Uses [`Self::from_bytes`] internally.
    #[inline]
    pub fn new(jumps: BitVec<u8>) -> Self {
        let bit_len = jumps.len();
        let bytes = jumps.into_vec().into();
        Self::from_bytes(bytes, bit_len)
    }

    /// Gets the raw bytes of the jump map.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        &self.table
    }

    /// Gets the length of the jump map, in bits.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns true if the jump map is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Constructs a jump map from raw bytes and length.
    ///
    /// Bit length represents number of used bits inside slice.
    ///
    /// Uses [`Self::from_bytes`] internally.
    ///
    /// # Panics
    ///
    /// Panics if the number of bits in `slice` (`slice.len() * 8`) is less than `bit_len`.
    #[inline]
    pub fn from_slice(slice: &[u8], bit_len: usize) -> Self {
        Self::from_bytes(Bytes::from(slice.to_vec()), bit_len)
    }

    /// Create new JumpTable directly from an existing Bytes.
    ///
    /// Bit length represents number of used bits inside slice.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() * 8` is less than `bit_len`.
    #[inline]
    pub fn from_bytes(bytes: Bytes, bit_len: usize) -> Self {
        Self::from_bytes_arc(Arc::new(bytes), bit_len)
    }

    /// Create new JumpTable directly from an existing `Arc<Bytes>`.
    ///
    /// Bit length represents number of used bits inside slice.
    ///
    /// # Panics
    ///
    /// Panics if `table.len() * 8` is less than `bit_len`.
    #[inline]
    pub fn from_bytes_arc(table: Arc<Bytes>, bit_len: usize) -> Self {
        const BYTE_LEN: usize = 8;
        // This assert establishes the invariant `is_valid` relies on:
        // every bit index below `bit_len` maps to an in-bounds byte.
        assert!(
            table.len() * BYTE_LEN >= bit_len,
            "slice bit length {} is less than bit_len {}",
            table.len() * BYTE_LEN,
            bit_len
        );

        // Cache the data pointer so `is_valid` can skip the `Arc` deref.
        // The pointee is owned by `table`, stored alongside the pointer below.
        let table_ptr = table.as_ptr();

        Self {
            table_ptr,
            table,
            len: bit_len,
        }
    }

    /// Checks if `pc` is a valid jump destination.
    /// Uses cached pointer and bit operations for faster access
    #[inline]
    pub fn is_valid(&self, pc: usize) -> bool {
        // SAFETY: the read only happens when `pc < self.len`, and the
        // constructor asserts `table.len() * 8 >= len`, so the byte index
        // `pc >> 3` is within the buffer `table_ptr` points into. Bit
        // `pc & 7` of that byte is the flag for `pc` (Lsb0 order, matching
        // `BitVec<u8>`'s default ordering used by `new`).
        pc < self.len && unsafe { *self.table_ptr.add(pc >> 3) & (1 << (pc & 7)) != 0 }
    }
}
173
#[cfg(test)]
mod tests {
    use super::*;

    // One byte holds only 8 bits, so asking for 10 must trip the constructor assert.
    #[test]
    #[should_panic(expected = "slice bit length 8 is less than bit_len 10")]
    fn test_jump_table_from_slice_panic() {
        let slice = &[0x00];
        let _ = JumpTable::from_slice(slice, 10);
    }

    #[test]
    fn test_jump_table_from_slice() {
        let slice = &[0x00];
        let jumptable = JumpTable::from_slice(slice, 3);
        assert_eq!(jumptable.len, 3);
    }

    // 0x0D = 0b0000_1101 -> bits 0, 2, 3 set (Lsb0);
    // 0x06 = 0b0000_0110 -> bits 9, 10 set in the second byte.
    #[test]
    fn test_is_valid() {
        let jump_table = JumpTable::from_slice(&[0x0D, 0x06], 13);

        assert_eq!(jump_table.len, 13);

        assert!(jump_table.is_valid(0)); // valid
        assert!(!jump_table.is_valid(1));
        assert!(jump_table.is_valid(2)); // valid
        assert!(jump_table.is_valid(3)); // valid
        assert!(!jump_table.is_valid(4));
        assert!(!jump_table.is_valid(5));
        assert!(!jump_table.is_valid(6));
        assert!(!jump_table.is_valid(7));
        assert!(!jump_table.is_valid(8));
        assert!(jump_table.is_valid(9)); // valid
        assert!(jump_table.is_valid(10)); // valid
        assert!(!jump_table.is_valid(11));
        assert!(!jump_table.is_valid(12));
    }

    // The serialized form is bitvec's own serde layout; this pins backward
    // compatibility with that wire format (data byte 5 = 0b101 -> bits 0, 2).
    #[test]
    #[cfg(feature = "serde")]
    fn test_serde_legacy_format() {
        let legacy_format = r#"
        {
            "order": "bitvec::order::Lsb0",
            "head": {
                "width": 8,
                "index": 0
            },
            "bits": 4,
            "data": [5]
        }"#;

        let table: JumpTable = serde_json::from_str(legacy_format).expect("Failed to deserialize");
        assert_eq!(table.len, 4);
        assert!(table.is_valid(0));
        assert!(!table.is_valid(1));
        assert!(table.is_valid(2));
        assert!(!table.is_valid(3));
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serde_roundtrip() {
        let original = JumpTable::from_slice(&[0x0D, 0x06], 13);

        // Serialize to JSON
        let serialized = serde_json::to_string(&original).expect("Failed to serialize");

        // Deserialize from JSON
        let deserialized: JumpTable =
            serde_json::from_str(&serialized).expect("Failed to deserialize");

        // Check that the deserialized table matches the original
        assert_eq!(original.len, deserialized.len);
        assert_eq!(original.table, deserialized.table);

        // Verify functionality is preserved
        for i in 0..13 {
            assert_eq!(
                original.is_valid(i),
                deserialized.is_valid(i),
                "Mismatch at index {i}"
            );
        }
    }
}
261
// NOTE(review): ad-hoc wall-clock benchmarks run as #[test]s; numbers are
// indicative only (no criterion, debug builds will skew results).
#[cfg(test)]
mod bench_is_valid {
    use super::*;
    use std::{sync::Arc, time::Instant};

    const ITERATIONS: usize = 1_000_000;
    const TEST_SIZE: usize = 10_000;

    // Builds a TEST_SIZE-bit table with every third bit set.
    fn create_test_table() -> BitVec<u8> {
        let mut bitvec = BitVec::from_vec(vec![0u8; TEST_SIZE.div_ceil(8)]);
        bitvec.resize(TEST_SIZE, false);
        for i in (0..TEST_SIZE).step_by(3) {
            bitvec.set(i, true);
        }
        bitvec
    }

    // Baseline implementation: dereferences the Arc on every lookup,
    // used to measure the benefit of JumpTable's cached pointer.
    #[derive(Clone)]
    pub(super) struct JumpTableWithArcDeref(pub Arc<BitVec<u8>>);

    impl JumpTableWithArcDeref {
        #[inline]
        pub(super) fn is_valid(&self, pc: usize) -> bool {
            // SAFETY: guarded by the preceding `pc < self.0.len()` check.
            pc < self.0.len() && unsafe { *self.0.get_unchecked(pc) }
        }
    }

    // Warms up, then times `test_fn` over ITERATIONS lookups and prints
    // ns/op, ops/sec, and the hit count (black_box prevents const-folding).
    fn benchmark_implementation<F>(name: &str, table: &F, test_fn: impl Fn(&F, usize) -> bool)
    where
        F: Clone,
    {
        // Warmup
        for i in 0..10_000 {
            std::hint::black_box(test_fn(table, i % TEST_SIZE));
        }

        let start = Instant::now();
        let mut count = 0;

        for i in 0..ITERATIONS {
            if test_fn(table, i % TEST_SIZE) {
                count += 1;
            }
        }

        let duration = start.elapsed();
        let ns_per_op = duration.as_nanos() as f64 / ITERATIONS as f64;
        let ops_per_sec = ITERATIONS as f64 / duration.as_secs_f64();

        println!("{name} Performance:");
        println!("  Time per op: {ns_per_op:.2} ns");
        println!("  Ops per sec: {ops_per_sec:.0}");
        println!("  True count: {count}");
        println!();

        std::hint::black_box(count);
    }

    // Compares cached-pointer JumpTable against the Arc-deref baseline.
    #[test]
    fn bench_is_valid() {
        println!("JumpTable is_valid() Benchmark Comparison");
        println!("=========================================");

        let bitvec = create_test_table();

        // Test cached pointer implementation
        let cached_table = JumpTable::new(bitvec.clone());
        benchmark_implementation("JumpTable (Cached Pointer)", &cached_table, |table, pc| {
            table.is_valid(pc)
        });

        // Test Arc deref implementation
        let arc_table = JumpTableWithArcDeref(Arc::new(bitvec));
        benchmark_implementation("JumpTableWithArcDeref (Arc)", &arc_table, |table, pc| {
            table.is_valid(pc)
        });

        println!("Benchmark completed successfully!");
    }

    // Same comparison under sequential (i) and strided ((i * 17)) access
    // patterns, reporting the relative speedup of the cached pointer.
    #[test]
    fn bench_different_access_patterns() {
        let bitvec = create_test_table();
        let cached_table = JumpTable::new(bitvec.clone());
        let arc_table = JumpTableWithArcDeref(Arc::new(bitvec));

        println!("Access Pattern Comparison");
        println!("========================");

        // Sequential access
        let start = Instant::now();
        for i in 0..ITERATIONS {
            std::hint::black_box(cached_table.is_valid(i % TEST_SIZE));
        }
        let cached_sequential = start.elapsed();

        let start = Instant::now();
        for i in 0..ITERATIONS {
            std::hint::black_box(arc_table.is_valid(i % TEST_SIZE));
        }
        let arc_sequential = start.elapsed();

        // Random access
        let start = Instant::now();
        for i in 0..ITERATIONS {
            std::hint::black_box(cached_table.is_valid((i * 17) % TEST_SIZE));
        }
        let cached_random = start.elapsed();

        let start = Instant::now();
        for i in 0..ITERATIONS {
            std::hint::black_box(arc_table.is_valid((i * 17) % TEST_SIZE));
        }
        let arc_random = start.elapsed();

        println!("Sequential Access:");
        println!(
            "  Cached: {:.2} ns/op",
            cached_sequential.as_nanos() as f64 / ITERATIONS as f64
        );
        println!(
            "  Arc:    {:.2} ns/op",
            arc_sequential.as_nanos() as f64 / ITERATIONS as f64
        );
        println!(
            "  Speedup: {:.1}x",
            arc_sequential.as_nanos() as f64 / cached_sequential.as_nanos() as f64
        );

        println!();
        println!("Random Access:");
        println!(
            "  Cached: {:.2} ns/op",
            cached_random.as_nanos() as f64 / ITERATIONS as f64
        );
        println!(
            "  Arc:    {:.2} ns/op",
            arc_random.as_nanos() as f64 / ITERATIONS as f64
        );
        println!(
            "  Speedup: {:.1}x",
            arc_random.as_nanos() as f64 / cached_random.as_nanos() as f64
        );
    }
}
406}