diff --git a/Cargo.lock b/Cargo.lock index fcc36ee..011c08d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,7 @@ dependencies = [ "nix", "polling_serial", "static_assertions", + "turn_lock", "vga_buffer", "x86_64", ] diff --git a/cache_utils/Cargo.toml b/cache_utils/Cargo.toml index 9a204c1..4f08f6d 100644 --- a/cache_utils/Cargo.toml +++ b/cache_utils/Cargo.toml @@ -18,9 +18,10 @@ atomic = { version = "0.5.0" } nix = { version = "0.18.0", optional = true } libc = { version = "0.2.77", optional = true } hashbrown = { version = "0.9.1", optional = true } +turn_lock = { path = "../turn_lock", optional = true} [features] -use_std = ["nix", "itertools/use_std", "libc", "cpuid/use_std"] +use_std = ["nix", "itertools/use_std", "libc", "cpuid/use_std", "turn_lock"] no_std = ["polling_serial", "vga_buffer", "hashbrown"] default = ["use_std"] diff --git a/cache_utils/src/calibrate_2t.rs b/cache_utils/src/calibrate_2t.rs index 1cef30a..7eedabf 100644 --- a/cache_utils/src/calibrate_2t.rs +++ b/cache_utils/src/calibrate_2t.rs @@ -1,4 +1,424 @@ -use crate::calibration::{CalibrateResult, CalibrateResult2T, HashMap, ASVP}; +use crate::calibration::Verbosity::{RawResult, Thresholds}; +use crate::calibration::{ + get_cache_slicing, get_vpn, CalibrateResult, CalibrationOptions, HashMap, ASVP, + SPURIOUS_THRESHOLD, +}; +use core::arch::x86_64 as arch_x86; +use itertools::Itertools; +use nix::sched::{sched_getaffinity, sched_setaffinity, CpuSet}; +use nix::unistd::Pid; +use nix::Error; +use std::cmp::min; +use std::ptr::null_mut; +//use std::sync::atomic::Ordering; +use std::mem::forget; +use std::sync::{Arc, Mutex}; +use std::thread; +use turn_lock::TurnHandle; + +pub struct CalibrateOperation2T<'a> { + pub prepare: unsafe fn(*const u8) -> (), + pub op: unsafe fn(*const u8) -> u64, + pub name: &'a str, + pub display_name: &'a str, +} + +pub struct CalibrateResult2T { + pub main_core: usize, + pub helper_core: usize, + pub res: Result, nix::Error>, // TODO + + // TODO +} + +pub unsafe fn calibrate_fixed_freq_2_thread>( + p: *const u8, + increment: usize, + len: isize, + cores: &mut I, + operations: &[CalibrateOperation2T], + options: CalibrationOptions, + core_per_socket: u8, +) -> Vec { + calibrate_fixed_freq_2_thread_impl( + p, + increment, + len, + cores, + operations, + options, + core_per_socket, + ) +} + +const OPTIMISED_ADDR_ITER_FACTOR: u32 = 16; + +struct HelperThreadParams { + stop: bool, + op: unsafe fn(*const u8), + address: *const u8, +} + +// TODO : Add the optimised address support +// TODO : Modularisation / factorisation of some of the common code with the single threaded no_std version ? + +#[cfg(feature = "use_std")] +fn calibrate_fixed_freq_2_thread_impl>( + p: *const u8, + increment: usize, + len: isize, + cores: &mut I, + operations: &[CalibrateOperation2T], + mut options: CalibrationOptions, + core_per_socket: u8, +) -> Vec { + if options.verbosity >= Thresholds { + println!( + "Calibrating {}...", + operations + .iter() + .map(|operation| { operation.display_name }) + .format(", ") + ); + } + + let bucket_size = options.hist_params.bucket_size; + + let to_bucket = |time: u64| -> usize { time as usize / bucket_size }; + let from_bucket = |bucket: usize| -> u64 { (bucket * bucket_size) as u64 }; + + let slicing = get_cache_slicing(core_per_socket); + + let h = if let Some(s) = slicing { + if s.can_hash() { + Some(|addr: usize| -> u8 { slicing.unwrap().hash(addr).unwrap() }) + } else { + None + } + } else { + None + }; + + let mut ret = Vec::new(); + + let mut turn_handles = TurnHandle::new( + 2, + HelperThreadParams { + stop: true, + op: operations[0].prepare, + address: null_mut(), + }, + ); + + let mut helper_turn_handle = Arc::new(Mutex::new(turn_handles.pop().unwrap())); + let mut main_turn_handle = turn_handles.pop().unwrap(); + + let mut params = main_turn_handle.wait(); + + if options.verbosity >= Thresholds { + print!("CSV: main_core, helper_core, address, "); + if h.is_some() { + print!("hash, "); + } + println!( + "{} min, {} median, {} max", + operations + .iter() + .map(|operation| operation.name) + .format(" min, "), + operations + .iter() + .map(|operation| operation.name) + .format(" median, "), + operations + .iter() + .map(|operation| operation.name) + .format(" max, ") + ); + } + + if options.verbosity >= RawResult { + print!("RESULT:main_core,helper_core,address,"); + if h.is_some() { + print!("hash,"); + } + println!( + "time,{}", + operations + .iter() + .map(|operation| operation.name) + .format(",") + ); + } + + let image_antecedent = match slicing { + Some(s) => s.image_antecedent(len as usize - 1), + None => None, + }; + if image_antecedent.is_some() { + options.hist_params.iterations *= OPTIMISED_ADDR_ITER_FACTOR; + } + + let old = sched_getaffinity(Pid::from_raw(0)).unwrap(); + + for (main_core, helper_core) in cores { + // set main thread affinity + + if options.verbosity >= Thresholds { + println!( + "Calibration for main_core {}, helper {}.", + main_core, helper_core + ); + + eprintln!( + "Calibration for main_core {}, helper {}.", + main_core, helper_core + ); + } + + let mut core = CpuSet::new(); + match core.set(main_core) { + Ok(_) => {} + Err(e) => { + ret.push(CalibrateResult2T { + main_core, + helper_core, + res: Err(e), + }); + continue; + } + } + + match sched_setaffinity(Pid::from_raw(0), &core) { + Ok(_) => {} + Err(e) => { + ret.push(CalibrateResult2T { + main_core, + helper_core, + res: Err(e), + }); + continue; + } + } + + let helper_thread = if helper_core != main_core { + params.stop = false; + // set up the helper thread + + let hc = helper_core; + let th = helper_turn_handle.clone(); + Some(thread::spawn(move || { + calibrate_fixed_freq_2_thread_helper(th, hc) + })) + } else { + None + }; + // do the calibration + let mut calibrate_result_vec = Vec::new(); + + let offsets: Box> = match image_antecedent { + Some(ref ima) => Box::new(ima.values().copied()), + None => Box::new((0..len as isize).step_by(increment)), + }; + + for i in offsets { + let pointer = unsafe { p.offset(i) }; + params.address = pointer; + + let hash = h.map(|h| h(pointer as usize)); + + if options.verbosity >= Thresholds { + print!("Calibration for {:p}", pointer); + if let Some(h) = hash { + print!(" (hash: {:x})", h) + } + println!(); + } + + // TODO add some useful impl to CalibrateResults + let mut calibrate_result = CalibrateResult { + page: get_vpn(pointer), + offset: i, + histogram: Vec::new(), + median: vec![0; operations.len()], + min: vec![0; operations.len()], + max: vec![0; operations.len()], + }; + calibrate_result.histogram.reserve(operations.len()); + + if helper_core != main_core { + for op in operations { + params.op = op.prepare; + let mut hist = vec![0; options.hist_params.bucket_number]; + for _ in 0..options.hist_params.iterations { + params = main_turn_handle.wait(); + let _time = unsafe { (op.op)(pointer) }; + } + for _ in 0..options.hist_params.iterations { + //params.next(); + params = main_turn_handle.wait(); + let time = unsafe { (op.op)(pointer) }; + let bucket = min(options.hist_params.bucket_number - 1, to_bucket(time)); + hist[bucket] += 1; + } + calibrate_result.histogram.push(hist); + } + } else { + for op in operations { + let mut hist = vec![0; options.hist_params.bucket_number]; + for _ in 0..options.hist_params.iterations { + unsafe { (op.prepare)(pointer) }; + unsafe { arch_x86::_mm_mfence() }; // Test with this ? + let _time = unsafe { (op.op)(pointer) }; + } + for _ in 0..options.hist_params.iterations { + unsafe { (op.prepare)(pointer) }; + unsafe { arch_x86::_mm_mfence() }; // Test with this ? + let time = unsafe { (op.op)(pointer) }; + let bucket = min(options.hist_params.bucket_number - 1, to_bucket(time)); + hist[bucket] += 1; + } + calibrate_result.histogram.push(hist); + } + } + let mut sums = vec![0; operations.len()]; + + let median_thresholds: Vec = calibrate_result + .histogram + .iter() + .map(|h| { + (options.hist_params.iterations - h[options.hist_params.bucket_number - 1]) / 2 + }) + .collect(); + + for j in 0..options.hist_params.bucket_number - 1 { + if options.verbosity >= RawResult { + print!("RESULT:{},{},{:p},", main_core, helper_core, pointer); + if let Some(h) = hash { + print!("{:x},", h); + } + print!("{}", from_bucket(j)); + } + // ignore the last bucket : spurious context switches etc. + for op in 0..operations.len() { + let hist = &calibrate_result.histogram[op][j]; + let min = &mut calibrate_result.min[op]; + let max = &mut calibrate_result.max[op]; + let med = &mut calibrate_result.median[op]; + let sum = &mut sums[op]; + if options.verbosity >= RawResult { + print!(",{}", hist); + } + + if *min == 0 { + // looking for min + if *hist > SPURIOUS_THRESHOLD { + *min = from_bucket(j); + } + } else if *hist > SPURIOUS_THRESHOLD { + *max = from_bucket(j); + } + + if *med == 0 { + *sum += *hist; + if *sum >= median_thresholds[op] { + *med = from_bucket(j); + } + } + } + if options.verbosity >= RawResult { + println!(); + } + } + if options.verbosity >= Thresholds { + for (j, op) in operations.iter().enumerate() { + println!( + "{}: min {}, median {}, max {}", + op.display_name, + calibrate_result.min[j], + calibrate_result.median[j], + calibrate_result.max[j] + ); + } + print!("CSV: {},{},{:p}, ", main_core, helper_core, pointer); + if let Some(h) = hash { + print!("{:x}, ", h) + } + println!( + "{}, {}, {}", + calibrate_result.min.iter().format(", "), + calibrate_result.median.iter().format(", "), + calibrate_result.max.iter().format(", ") + ); + } + calibrate_result_vec.push(calibrate_result); + } + + ret.push(CalibrateResult2T { + main_core, + helper_core, + res: Ok(calibrate_result_vec), + }); + + if helper_core != main_core { + // terminate the thread + params.stop = true; + params.next(); + params = main_turn_handle.wait(); + // join thread. + helper_thread.unwrap().join(); + // FIXME error handling + } + } + + sched_setaffinity(Pid::from_raw(0), &old).unwrap(); + + ret + // return the result + // TODO +} + +fn calibrate_fixed_freq_2_thread_helper( + turn_handle: Arc>>, + helper_core: usize, +) -> Result<(), Error> { + let mut turn_handle = turn_handle.lock().unwrap(); + // set thread affinity + let mut core = CpuSet::new(); + match core.set(helper_core) { + Ok(_) => {} + Err(e) => { + let mut params = turn_handle.wait(); + params.stop = true; + params.next(); + return Err(e); + } + } + + match sched_setaffinity(Pid::from_raw(0), &core) { + Ok(_) => {} + Err(_e) => { + unimplemented!(); + } + } + + loop { + // grab lock + let params = turn_handle.wait(); + if params.stop { + params.next(); + return Ok(()); + } + // get the relevant parameters + let addr: *const u8 = params.address; + let op = params.op; + unsafe { op(addr) }; + // release lock + params.next() + } +} + +// ------------------- Analysis ------------------ pub fn calibration_result_to_ASVP T>( results: Vec, diff --git a/cache_utils/src/calibration.rs b/cache_utils/src/calibration.rs index caec494..f6b41c4 100644 --- a/cache_utils/src/calibration.rs +++ b/cache_utils/src/calibration.rs @@ -9,10 +9,6 @@ use core::arch::x86_64 as arch_x86; #[cfg(feature = "no_std")] use polling_serial::{serial_print as print, serial_println as println}; -//#[cfg(feature = "use_std")] -//use nix::errno::Errno; -#[cfg(feature = "use_std")] -use nix::sched::{sched_getaffinity, sched_setaffinity, CpuSet}; #[cfg(feature = "use_std")] use nix::unistd::Pid; //#[cfg(feature = "use_std")] @@ -288,7 +284,7 @@ pub unsafe fn calibrate( ) } -const SPURIOUS_THRESHOLD: u32 = 1; +pub const SPURIOUS_THRESHOLD: u32 = 1; fn calibrate_impl_fixed_freq( p: *const u8, increment: usize, @@ -480,55 +476,6 @@ fn calibrate_impl_fixed_freq( ret } -#[cfg(feature = "use_std")] -pub struct CalibrateOperation2T<'a> { - pub prepare: unsafe fn(*const u8) -> (), - pub op: unsafe fn(*const u8) -> u64, - pub name: &'a str, - pub display_name: &'a str, -} - -#[cfg(feature = "use_std")] -pub struct CalibrateResult2T { - pub main_core: usize, - pub helper_core: usize, - pub res: Result, nix::Error>, // TODO - - // TODO -} - -fn wait(turn_lock: &AtomicBool, turn: bool) { - while turn_lock.load(Ordering::Acquire) != turn { - spin_loop_hint(); - } - assert_eq!(turn_lock.load(Ordering::Relaxed), turn); -} - -fn next(turn_lock: &AtomicBool) { - turn_lock.fetch_xor(true, Ordering::Release); -} - -#[cfg(feature = "use_std")] -pub unsafe fn calibrate_fixed_freq_2_thread>( - p: *const u8, - increment: usize, - len: isize, - cores: &mut I, - operations: &[CalibrateOperation2T], - options: CalibrationOptions, - core_per_socket: u8, -) -> Vec { - calibrate_fixed_freq_2_thread_impl( - p, - increment, - len, - cores, - operations, - options, - core_per_socket, - ) -} - pub fn get_cache_slicing(core_per_socket: u8) -> Option { if let Some(uarch) = MicroArchitecture::get_micro_architecture() { if let Some(vendor_family_model_stepping) = MicroArchitecture::get_family_model_stepping() { @@ -547,366 +494,6 @@ pub fn get_cache_slicing(core_per_socket: u8) -> Option { } } -const OPTIMISED_ADDR_ITER_FACTOR: u32 = 16; - -// TODO : Add the optimised address support -// TODO : Modularisation / factorisation of some of the common code with the single threaded no_std version ? - -#[cfg(feature = "use_std")] -fn calibrate_fixed_freq_2_thread_impl>( - p: *const u8, - increment: usize, - len: isize, - cores: &mut I, - operations: &[CalibrateOperation2T], - mut options: CalibrationOptions, - core_per_socket: u8, -) -> Vec { - if options.verbosity >= Thresholds { - println!( - "Calibrating {}...", - operations - .iter() - .map(|operation| { operation.display_name }) - .format(", ") - ); - } - - let bucket_size = options.hist_params.bucket_size; - - let to_bucket = |time: u64| -> usize { time as usize / bucket_size }; - let from_bucket = |bucket: usize| -> u64 { (bucket * bucket_size) as u64 }; - - let slicing = get_cache_slicing(core_per_socket); - - let h = if let Some(s) = slicing { - if s.can_hash() { - Some(|addr: usize| -> u8 { slicing.unwrap().hash(addr).unwrap() }) - } else { - None - } - } else { - None - }; - - let mut ret = Vec::new(); - - let helper_thread_params = Arc::new(HelperThreadParams { - turn: AtomicBool::new(false), - stop: AtomicBool::new(true), - op: Atomic::new(operations[0].prepare), - address: AtomicPtr::new(null_mut()), - }); - - if options.verbosity >= Thresholds { - print!("CSV: main_core, helper_core, address, "); - if h.is_some() { - print!("hash, "); - } - println!( - "{} min, {} median, {} max", - operations - .iter() - .map(|operation| operation.name) - .format(" min, "), - operations - .iter() - .map(|operation| operation.name) - .format(" median, "), - operations - .iter() - .map(|operation| operation.name) - .format(" max, ") - ); - } - - if options.verbosity >= RawResult { - print!("RESULT:main_core,helper_core,address,"); - if h.is_some() { - print!("hash,"); - } - println!( - "time,{}", - operations - .iter() - .map(|operation| operation.name) - .format(",") - ); - } - - let image_antecedent = match slicing { - Some(s) => s.image_antecedent(len as usize - 1), - None => None, - }; - if image_antecedent.is_some() { - options.hist_params.iterations *= OPTIMISED_ADDR_ITER_FACTOR; - } - - let old = sched_getaffinity(Pid::from_raw(0)).unwrap(); - - for (main_core, helper_core) in cores { - // set main thread affinity - - if options.verbosity >= Thresholds { - println!( - "Calibration for main_core {}, helper {}.", - main_core, helper_core - ); - - eprintln!( - "Calibration for main_core {}, helper {}.", - main_core, helper_core - ); - } - - let mut core = CpuSet::new(); - match core.set(main_core) { - Ok(_) => {} - Err(e) => { - ret.push(CalibrateResult2T { - main_core, - helper_core, - res: Err(e), - }); - continue; - } - } - - match sched_setaffinity(Pid::from_raw(0), &core) { - Ok(_) => {} - Err(e) => { - ret.push(CalibrateResult2T { - main_core, - helper_core, - res: Err(e), - }); - continue; - } - } - - let helper_thread = if helper_core != main_core { - helper_thread_params.stop.store(false, Ordering::Relaxed); - // set up the helper thread - - let htp = helper_thread_params.clone(); - let hc = helper_core; - Some(thread::spawn(move || { - calibrate_fixed_freq_2_thread_helper(htp, hc) - })) - } else { - None - }; - // do the calibration - let mut calibrate_result_vec = Vec::new(); - - let offsets: Box> = match image_antecedent { - Some(ref ima) => Box::new(ima.values().copied()), - None => Box::new((0..len as isize).step_by(increment)), - }; - - for i in offsets { - let pointer = unsafe { p.offset(i) }; - helper_thread_params - .address - .store(pointer as *mut u8, Ordering::Relaxed); - - let hash = h.map(|h| h(pointer as usize)); - - if options.verbosity >= Thresholds { - print!("Calibration for {:p}", pointer); - if let Some(h) = hash { - print!(" (hash: {:x})", h) - } - println!(); - } - - // TODO add some useful impl to CalibrateResults - let mut calibrate_result = CalibrateResult { - page: get_vpn(pointer), - offset: i, - histogram: Vec::new(), - median: vec![0; operations.len()], - min: vec![0; operations.len()], - max: vec![0; operations.len()], - }; - calibrate_result.histogram.reserve(operations.len()); - - if helper_core != main_core { - for op in operations { - helper_thread_params.op.store(op.prepare, Ordering::Relaxed); - let mut hist = vec![0; options.hist_params.bucket_number]; - for _ in 0..options.hist_params.iterations { - next(&helper_thread_params.turn); - wait(&helper_thread_params.turn, false); - let _time = unsafe { (op.op)(pointer) }; - } - for _ in 0..options.hist_params.iterations { - next(&helper_thread_params.turn); - wait(&helper_thread_params.turn, false); - let time = unsafe { (op.op)(pointer) }; - let bucket = min(options.hist_params.bucket_number - 1, to_bucket(time)); - hist[bucket] += 1; - } - calibrate_result.histogram.push(hist); - } - } else { - for op in operations { - let mut hist = vec![0; options.hist_params.bucket_number]; - for _ in 0..options.hist_params.iterations { - unsafe { (op.prepare)(pointer) }; - unsafe { arch_x86::_mm_mfence() }; // Test with this ? - let _time = unsafe { (op.op)(pointer) }; - } - for _ in 0..options.hist_params.iterations { - unsafe { (op.prepare)(pointer) }; - unsafe { arch_x86::_mm_mfence() }; // Test with this ? - let time = unsafe { (op.op)(pointer) }; - let bucket = min(options.hist_params.bucket_number - 1, to_bucket(time)); - hist[bucket] += 1; - } - calibrate_result.histogram.push(hist); - } - } - let mut sums = vec![0; operations.len()]; - - let median_thresholds: Vec = calibrate_result - .histogram - .iter() - .map(|h| { - (options.hist_params.iterations - h[options.hist_params.bucket_number - 1]) / 2 - }) - .collect(); - - for j in 0..options.hist_params.bucket_number - 1 { - if options.verbosity >= RawResult { - print!("RESULT:{},{},{:p},", main_core, helper_core, pointer); - if let Some(h) = hash { - print!("{:x},", h); - } - print!("{}", from_bucket(j)); - } - // ignore the last bucket : spurious context switches etc. - for op in 0..operations.len() { - let hist = &calibrate_result.histogram[op][j]; - let min = &mut calibrate_result.min[op]; - let max = &mut calibrate_result.max[op]; - let med = &mut calibrate_result.median[op]; - let sum = &mut sums[op]; - if options.verbosity >= RawResult { - print!(",{}", hist); - } - - if *min == 0 { - // looking for min - if *hist > SPURIOUS_THRESHOLD { - *min = from_bucket(j); - } - } else if *hist > SPURIOUS_THRESHOLD { - *max = from_bucket(j); - } - - if *med == 0 { - *sum += *hist; - if *sum >= median_thresholds[op] { - *med = from_bucket(j); - } - } - } - if options.verbosity >= RawResult { - println!(); - } - } - if options.verbosity >= Thresholds { - for (j, op) in operations.iter().enumerate() { - println!( - "{}: min {}, median {}, max {}", - op.display_name, - calibrate_result.min[j], - calibrate_result.median[j], - calibrate_result.max[j] - ); - } - print!("CSV: {},{},{:p}, ", main_core, helper_core, pointer); - if let Some(h) = hash { - print!("{:x}, ", h) - } - println!( - "{}, {}, {}", - calibrate_result.min.iter().format(", "), - calibrate_result.median.iter().format(", "), - calibrate_result.max.iter().format(", ") - ); - } - calibrate_result_vec.push(calibrate_result); - } - - ret.push(CalibrateResult2T { - main_core, - helper_core, - res: Ok(calibrate_result_vec), - }); - - if helper_core != main_core { - // terminate the thread - helper_thread_params.stop.store(true, Ordering::Relaxed); - next(&helper_thread_params.turn); - wait(&helper_thread_params.turn, false); - // join thread. - helper_thread.unwrap().join(); - } - } - - sched_setaffinity(Pid::from_raw(0), &old).unwrap(); - - ret - // return the result - // TODO -} -#[cfg(feature = "use_std")] -struct HelperThreadParams { - turn: AtomicBool, - stop: AtomicBool, - op: Atomic, - address: AtomicPtr, -} - -#[cfg(feature = "use_std")] -fn calibrate_fixed_freq_2_thread_helper( - params: Arc, - helper_core: usize, -) -> Result<(), Error> { - // set thread affinity - let mut core = CpuSet::new(); - match core.set(helper_core) { - Ok(_) => {} - Err(_e) => { - unimplemented!(); - } - } - - match sched_setaffinity(Pid::from_raw(0), &core) { - Ok(_) => {} - Err(_e) => { - unimplemented!(); - } - } - - loop { - // grab lock - wait(¶ms.turn, true); - if params.stop.load(Ordering::Relaxed) { - next(¶ms.turn); - return Ok(()); - } - // get the relevant parameters - let addr: *const u8 = params.address.load(Ordering::Relaxed); - let op = params.op.load(Ordering::Relaxed); - unsafe { op(addr) }; - // release lock - next(¶ms.turn); - } -} - #[allow(non_snake_case)] pub fn calibrate_L3_miss_hit( array: &[u8], diff --git a/covert_channels_evaluation/src/lib.rs b/covert_channels_evaluation/src/lib.rs index 85e3301..4de5e81 100644 --- a/covert_channels_evaluation/src/lib.rs +++ b/covert_channels_evaluation/src/lib.rs @@ -27,14 +27,19 @@ use std::fmt::Debug; use std::sync::Arc; use std::thread; +/* TODO : replace page with a handle type, + require exclusive handle access, + Handle protected by the turn lock +*/ /** * Safety considerations : Not ensure thread safety, need proper locking as needed. */ -pub trait CovertChannel: Send + Sync + CoreSpec + Debug { +pub trait CovertChnel: Send + Sync + CoreSpec + Debug { + type Handle; const BIT_PER_PAGE: usize; - unsafe fn transmit(&self, page: *const u8, bits: &mut BitIterator); - unsafe fn receive(&self, page: *const u8) -> Vec; - unsafe fn ready_page(&mut self, page: *const u8); + unsafe fn transmit(&self, handle: &mut Handle, bits: &mut BitIterator); + unsafe fn receive(&self, handle: &mut Handle) -> Vec; + unsafe fn ready_page(&mut self, page: *const u8) -> Handle; } #[derive(Debug)] diff --git a/turn_lock/src/lib.rs b/turn_lock/src/lib.rs index 19e30c6..90ff755 100644 --- a/turn_lock/src/lib.rs +++ b/turn_lock/src/lib.rs @@ -4,6 +4,9 @@ use std::ops::{Deref, DerefMut}; use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; use std::sync::Arc; +// FIXME There may be significant unsafety if wait is called twice ? +// Add some extra mutual exclusion ? + pub struct RawTurnLock { turn: AtomicUsize, num_turns: usize, @@ -90,14 +93,14 @@ impl TurnHandle { result } - unsafe fn guard(&mut self) -> TurnLockGuard { + unsafe fn guard(&self) -> TurnLockGuard { TurnLockGuard { handle: &*self, marker: PhantomData, } } - pub fn wait(&mut self) -> TurnLockGuard { + pub fn wait(&self) -> TurnLockGuard { unsafe { self.raw.lock.wait(self.index) }; // Safety: the turn lock is now held unsafe { self.guard() } @@ -143,6 +146,7 @@ impl<'a, T> DerefMut for TurnLockGuard<'a, T> { } } +unsafe impl Send for TurnHandle {} #[cfg(test)] mod tests { use crate::TurnHandle; @@ -156,13 +160,13 @@ mod tests { fn three_turns() { let mut v = TurnHandle::<()>::new(3, ()); let t0 = v[0].wait(); - t0.next(); + drop(t0); let t1 = v[1].wait(); - t1.next(); + drop(t1); let t2 = v[2].wait(); - t2.next(); + drop(t2); let t0 = v[0].wait(); - t0.next(); + drop(t0); //assert_eq!(v[2].current(), 1); } }