diff --git a/prefetcher_reverse/src/ip_tool.rs b/prefetcher_reverse/src/ip_tool.rs index 346c030..dcbf0f8 100644 --- a/prefetcher_reverse/src/ip_tool.rs +++ b/prefetcher_reverse/src/ip_tool.rs @@ -2,9 +2,12 @@ use bitvec::prelude::*; use cache_utils::mmap::MMappedMemory; use lazy_static::lazy_static; use std::collections::LinkedList; +use std::ptr::copy_nonoverlapping; use std::sync::Mutex; struct WXRange { + start: usize, + end: usize, // points to the last valid byte bitmap: BitVec, // fixme bit vector pages: Vec>, } @@ -50,8 +53,91 @@ const TIMED_CLFLUSH: FunctionTemplate = FunctionTemplate { end: timed_clflush_template_end as *const u8, }; -pub fn allocate_function(align: usize, template: FunctionTemplate) -> Function { - unimplemented!() +impl WXRange { + unsafe fn allocate( + &mut self, + align: usize, + offset: usize, + length: usize, + mask: usize, + round_mask: usize, + ) -> Result<*mut u8, ()> { + // In each range, we want to find base = 2^a * k such that start <= base + align < start + 2^a + // This can be done with k = ceil(start - align / 2^a). + // 2^a * k can likely be computed with some clever bit tricks. + // \o/ + let start = self.start; + let mut candidate = (start - offset + mask) & round_mask + offset; + assert_eq!(candidate & mask, offset); + assert!(candidate > start); + while candidate + length <= self.end { + let bit_range = &mut self.bitmap[(candidate - start)..(candidate - start + length)]; + if !bit_range.any() { + bit_range.set_all(true); + return Ok(candidate as *mut u8); + } + candidate += align; + } + Err(()) + } +} + +impl WXAllocator { + pub unsafe fn allocate( + &mut self, + align: usize, + offset: usize, + length: usize, + ) -> Result<*mut u8, ()> { + if align.count_ones() != 1 && offset < align { + return Err(()); // FIXME Error type. + } + let mask = align - 1; + let round_mask = !mask; + for range in self.ranges.iter_mut() { + if let Ok(p) = unsafe { range.allocate(align, offset, length, mask, round_mask) } { + return Ok(p); + } + } + // Now we need to allocate a new page ^^' + return Err(()); + } +} + +impl Function { + pub fn try_new( + align: usize, + offset: usize, + template: FunctionTemplate, + ) -> Result { + // find suitable target + let mut allocator = wx_allocator.lock().unwrap(); + if align.count_ones() != 1 && offset < align { + return Err(()); // FIXME Error type. + } + let mask = align - 1; + let real_offset = (offset - (template.ip as usize) + (template.start as usize)) & mask; + let length = (template.end as usize) - (template.start as usize); + + let p = unsafe { allocator.allocate(align, real_offset, length) }?; + unsafe { copy_nonoverlapping(template.start as *const u8, p, length) }; + let res = Function { + fun: unsafe { + std::mem::transmute::<*mut u8, unsafe extern "C" fn(*const u8) -> u64>(p) + }, + ip: unsafe { p.add(template.ip as usize - template.start as usize) }, + end: unsafe { p.add(length) }, + size: length, + }; + Ok(res) + } +} + +impl Drop for Function { + fn drop(&mut self) { + // Find the correct range, and deallocate all the bits + todo!() + } } global_asm!(