revm_bytecode/legacy/
jump_map.rs1use bitvec::vec::BitVec;
2use core::{
3 cmp::Ordering,
4 hash::{Hash, Hasher},
5};
6use primitives::{hex, Bytes, OnceLock};
7use std::{fmt::Debug, sync::Arc};
8
9#[derive(Clone, Eq)]
13#[cfg_attr(feature = "serde", derive(serde::Serialize))]
14pub struct JumpTable {
15 #[cfg_attr(feature = "serde", serde(skip))]
17 table_ptr: *const u8,
18 len: usize,
20 table: Arc<Bytes>,
22}
23
24unsafe impl Send for JumpTable {}
26unsafe impl Sync for JumpTable {}
27
28impl PartialEq for JumpTable {
29 fn eq(&self, other: &Self) -> bool {
30 self.table.eq(&other.table)
31 }
32}
33
34impl Hash for JumpTable {
35 fn hash<H: Hasher>(&self, state: &mut H) {
36 self.table.hash(state);
37 }
38}
39
40impl PartialOrd for JumpTable {
41 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Ord for JumpTable {
47 fn cmp(&self, other: &Self) -> Ordering {
48 self.table.cmp(&other.table)
49 }
50}
51
52impl Debug for JumpTable {
53 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54 f.debug_struct("JumpTable")
55 .field("map", &hex::encode(self.table.as_ref()))
56 .finish()
57 }
58}
59
60impl Default for JumpTable {
61 #[inline]
62 fn default() -> Self {
63 static DEFAULT: OnceLock<JumpTable> = OnceLock::new();
64 DEFAULT.get_or_init(|| Self::new(BitVec::default())).clone()
65 }
66}
67
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for JumpTable {
    /// Deserializes the `{ len, table }` pair and rebuilds the cached pointer.
    ///
    /// Input whose `len` claims more bits than `table` holds is rejected with a
    /// deserialization error. Without this check, `JumpTable::from_bytes_arc`
    /// would panic on such data.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(serde::Deserialize)]
        struct JumpTableSerde {
            len: usize,
            table: Arc<Bytes>,
        }

        let data = JumpTableSerde::deserialize(deserializer)?;

        // Validate before construction so malformed (possibly untrusted) input
        // surfaces as `D::Error` instead of aborting via `assert!`.
        let available_bits = data.table.len().saturating_mul(8);
        if data.len > available_bits {
            return Err(serde::de::Error::custom(format_args!(
                "slice bit length {available_bits} is less than bit_len {}",
                data.len
            )));
        }

        Ok(Self::from_bytes_arc(data.table, data.len))
    }
}
84
85impl JumpTable {
86 #[inline]
90 pub fn new(jumps: BitVec<u8>) -> Self {
91 let bit_len = jumps.len();
92 let bytes = jumps.into_vec().into();
93 Self::from_bytes(bytes, bit_len)
94 }
95
96 #[inline]
98 pub fn as_slice(&self) -> &[u8] {
99 &self.table
100 }
101
102 #[inline]
104 pub fn len(&self) -> usize {
105 self.len
106 }
107
108 #[inline]
110 pub fn is_empty(&self) -> bool {
111 self.len == 0
112 }
113
114 #[inline]
124 pub fn from_slice(slice: &[u8], bit_len: usize) -> Self {
125 Self::from_bytes(Bytes::from(slice.to_vec()), bit_len)
126 }
127
128 #[inline]
134 pub fn from_bytes(bytes: Bytes, bit_len: usize) -> Self {
135 Self::from_bytes_arc(Arc::new(bytes), bit_len)
136 }
137
138 #[inline]
144 pub fn from_bytes_arc(table: Arc<Bytes>, bit_len: usize) -> Self {
145 const BYTE_LEN: usize = 8;
146 assert!(
147 table.len() * BYTE_LEN >= bit_len,
148 "slice bit length {} is less than bit_len {}",
149 table.len() * BYTE_LEN,
150 bit_len
151 );
152
153 let table_ptr = table.as_ptr();
154
155 Self {
156 table_ptr,
157 table,
158 len: bit_len,
159 }
160 }
161
162 #[inline]
165 pub fn is_valid(&self, pc: usize) -> bool {
166 pc < self.len && unsafe { *self.table_ptr.add(pc >> 3) & (1 << (pc & 7)) != 0 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic(expected = "slice bit length 8 is less than bit_len 10")]
    fn test_jump_table_from_slice_panic() {
        // One byte supplies only 8 bits, fewer than the 10 requested.
        let _table = JumpTable::from_slice(&[0x00], 10);
    }

    #[test]
    fn test_jump_table_from_slice() {
        let table = JumpTable::from_slice(&[0x00], 3);
        assert_eq!(table.len, 3);
    }

    #[test]
    fn test_is_valid() {
        // 0x0D = 0b0000_1101 sets bits 0, 2, 3; 0x06 = 0b0000_0110 sets bits 9, 10.
        let jump_table = JumpTable::from_slice(&[0x0D, 0x06], 13);
        assert_eq!(jump_table.len, 13);

        const VALID: [usize; 5] = [0, 2, 3, 9, 10];
        for pc in 0..13 {
            assert_eq!(jump_table.is_valid(pc), VALID.contains(&pc));
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serde_roundtrip() {
        let original = JumpTable::from_slice(&[0x0D, 0x06], 13);

        let json = serde_json::to_string(&original).expect("Failed to serialize");
        let restored: JumpTable = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(original.len, restored.len);
        assert_eq!(original.table, restored.table);

        // The rebuilt cached pointer must give the same answers as the original.
        (0..13).for_each(|i| {
            assert_eq!(
                original.is_valid(i),
                restored.is_valid(i),
                "Mismatch at index {i}"
            );
        });
    }
}
235
#[cfg(test)]
mod bench_is_valid {
    //! Rough timing comparison between the cached-pointer `JumpTable` and an
    //! `Arc<BitVec<u8>>` baseline. These run as ordinary tests and only print
    //! timings; nothing about speed is asserted.

    use super::*;
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    const ITERATIONS: usize = 1_000_000;
    const TEST_SIZE: usize = 10_000;

    /// Builds a `TEST_SIZE`-bit vector in which every third bit (0, 3, 6, ...) is set.
    fn create_test_table() -> BitVec<u8> {
        let byte_len = TEST_SIZE.div_ceil(8);
        let mut bits = BitVec::from_vec(vec![0u8; byte_len]);
        bits.resize(TEST_SIZE, false);
        (0..TEST_SIZE)
            .step_by(3)
            .for_each(|pc| bits.set(pc, true));
        bits
    }

    /// Baseline lookup structure that goes through `Arc` and `BitVec` on every call.
    #[derive(Clone)]
    pub(super) struct JumpTableWithArcDeref(pub Arc<BitVec<u8>>);

    impl JumpTableWithArcDeref {
        #[inline]
        pub(super) fn is_valid(&self, pc: usize) -> bool {
            let bits: &BitVec<u8> = &self.0;
            // SAFETY: the bounds check left of `&&` guarantees `pc < bits.len()`.
            pc < bits.len() && unsafe { *bits.get_unchecked(pc) }
        }
    }

    /// Average nanoseconds per lookup over a run of `ITERATIONS` lookups.
    fn per_op_nanos(elapsed: Duration) -> f64 {
        elapsed.as_nanos() as f64 / ITERATIONS as f64
    }

    /// Times `ITERATIONS` calls of `lookup`, feeding it `(i * stride) % TEST_SIZE`.
    fn time_lookups(stride: usize, lookup: impl Fn(usize) -> bool) -> Duration {
        let start = Instant::now();
        for i in 0..ITERATIONS {
            std::hint::black_box(lookup((i * stride) % TEST_SIZE));
        }
        start.elapsed()
    }

    /// Warms up, then times `ITERATIONS` calls of `test_fn` against `table`
    /// and prints the per-op time, throughput and number of `true` results.
    fn benchmark_implementation<F>(name: &str, table: &F, test_fn: impl Fn(&F, usize) -> bool)
    where
        F: Clone,
    {
        // Warm-up pass; results are discarded through `black_box`.
        for warmup_pc in 0..10_000 {
            std::hint::black_box(test_fn(table, warmup_pc % TEST_SIZE));
        }

        let start = Instant::now();
        let mut count = 0;
        for i in 0..ITERATIONS {
            if test_fn(table, i % TEST_SIZE) {
                count += 1;
            }
        }
        let duration = start.elapsed();

        let ops_per_sec = ITERATIONS as f64 / duration.as_secs_f64();
        let ns_per_op = per_op_nanos(duration);

        println!("{name} Performance:");
        println!(" Time per op: {ns_per_op:.2} ns");
        println!(" Ops per sec: {ops_per_sec:.0}");
        println!(" True count: {count}");
        println!();

        std::hint::black_box(count);
    }

    #[test]
    fn bench_is_valid() {
        println!("JumpTable is_valid() Benchmark Comparison");
        println!("=========================================");

        let bitvec = create_test_table();

        let cached_table = JumpTable::new(bitvec.clone());
        benchmark_implementation(
            "JumpTable (Cached Pointer)",
            &cached_table,
            JumpTable::is_valid,
        );

        let arc_table = JumpTableWithArcDeref(Arc::new(bitvec));
        benchmark_implementation(
            "JumpTableWithArcDeref (Arc)",
            &arc_table,
            JumpTableWithArcDeref::is_valid,
        );

        println!("Benchmark completed successfully!");
    }

    #[test]
    fn bench_different_access_patterns() {
        let bitvec = create_test_table();
        let cached_table = JumpTable::new(bitvec.clone());
        let arc_table = JumpTableWithArcDeref(Arc::new(bitvec));

        println!("Access Pattern Comparison");
        println!("========================");

        // Stride 1 walks positions in order; stride 17 hops around the table.
        let cached_sequential = time_lookups(1, |pc| cached_table.is_valid(pc));
        let arc_sequential = time_lookups(1, |pc| arc_table.is_valid(pc));
        let cached_random = time_lookups(17, |pc| cached_table.is_valid(pc));
        let arc_random = time_lookups(17, |pc| arc_table.is_valid(pc));

        println!("Sequential Access:");
        println!(" Cached: {:.2} ns/op", per_op_nanos(cached_sequential));
        println!(" Arc: {:.2} ns/op", per_op_nanos(arc_sequential));
        println!(
            " Speedup: {:.1}x",
            arc_sequential.as_nanos() as f64 / cached_sequential.as_nanos() as f64
        );

        println!();
        println!("Random Access:");
        println!(" Cached: {:.2} ns/op", per_op_nanos(cached_random));
        println!(" Arc: {:.2} ns/op", per_op_nanos(arc_random));
        println!(
            " Speedup: {:.1}x",
            arc_random.as_nanos() as f64 / cached_random.as_nanos() as f64
        );
    }
}