diff --git a/covert_channels_evaluation/src/lib.rs b/covert_channels_evaluation/src/lib.rs index 410803e..85e3301 100644 --- a/covert_channels_evaluation/src/lib.rs +++ b/covert_channels_evaluation/src/lib.rs @@ -61,7 +61,14 @@ impl CovertChannelBenchmarkResult { } pub fn csv(&self) -> String { - format!("{},{},{},{},{}", self.num_bytes_transmitted, self.num_bit_errors, self.error_rate, self.time_rdtsc, self.time_seconds.as_nanos()) + format!( + "{},{},{},{},{}", + self.num_bytes_transmitted, + self.num_bit_errors, + self.error_rate, + self.time_rdtsc, + self.time_seconds.as_nanos() + ) } pub fn csv_header() -> String { diff --git a/turn_lock/src/lib.rs b/turn_lock/src/lib.rs index f00cef1..19e30c6 100644 --- a/turn_lock/src/lib.rs +++ b/turn_lock/src/lib.rs @@ -1,52 +1,151 @@ +use std::cell::UnsafeCell; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; use std::sync::Arc; -pub struct TurnLock { - turn: Arc, - index: usize, +pub struct RawTurnLock { + turn: AtomicUsize, num_turns: usize, } -impl TurnLock { - pub fn new(num_turns: usize) -> Vec { - let turn = Arc::new(AtomicUsize::new(0)); - let mut r = Vec::new(); - for i in 0..num_turns { - r.push(TurnLock { - turn: turn.clone(), - index: i, - num_turns, - }) - } - r - } - pub fn wait(&mut self) { - while self.turn.load(Ordering::Acquire) != self.index { - spin_loop_hint(); - } - assert_eq!(self.turn.load(Ordering::Relaxed), self.index); - } - pub fn next(&mut self) { - assert_eq!(self.turn.load(Ordering::Relaxed), self.index); - let r = self.turn.compare_exchange( - self.index, - (self.index + 1) % self.num_turns, - Ordering::Release, - Ordering::Relaxed, - ); - if r.expect("Failed to release lock") != self.index { - panic!("Released lock out of turn"); +impl RawTurnLock { + pub fn new(num_turns: usize) -> Self { + Self { + turn: AtomicUsize::new(0), + num_turns, } } - pub fn current(&self) -> usize { - self.turn.load(Ordering::SeqCst) + pub fn is_poisoned(&self) -> bool { + let current = self.turn.load(Ordering::Relaxed); + current >= self.num_turns + } + + pub unsafe fn try_wait(&self, turn: usize) -> bool { + let current = self.turn.load(Ordering::Acquire); + current == turn + } + + pub unsafe fn wait(&self, turn: usize) { + let mut current = self.turn.load(Ordering::Acquire); + while current < self.num_turns && current != turn { + spin_loop_hint(); + current = self.turn.load(Ordering::Acquire); + } + if current >= self.num_turns { + panic!("Waiting on a poisoned turn lock"); + } + if self.turn.load(Ordering::Relaxed) != turn { + panic!("Someone stole the turn"); + } + } + + pub unsafe fn next(&self, turn: usize) { + if self.is_poisoned() { + panic!("Using poisoned turn lock"); + } + let current = self.turn.load(Ordering::Relaxed); + if current != turn { + panic!("Releasing turn lock out of turn"); + } + + let r = self.turn.compare_exchange( + turn, + (turn + 1) % self.num_turns, + Ordering::Release, + Ordering::Relaxed, + ); + if r.expect("Failed to release turn lock") != turn { + panic!("Released turn lock out of turn"); + } + } +} + +struct TurnLockData { + pub lock: RawTurnLock, + pub data: UnsafeCell, +} + +pub struct TurnHandle { + raw: Arc>, + index: usize, +} + +impl TurnHandle { + pub fn new(num_turns: usize, data: T) -> Vec> { + let turn_lock = RawTurnLock::new(num_turns); + let turn_lock_data = TurnLockData { + lock: turn_lock, + data: UnsafeCell::new(data), + }; + let arc = Arc::new(turn_lock_data); + let mut result = Vec::with_capacity(num_turns); + for i in 0..num_turns { + result.push(Self { + raw: arc.clone(), + index: i, + }) + } + result + } + + unsafe fn guard(&mut self) -> TurnLockGuard { + TurnLockGuard { + handle: &*self, + marker: PhantomData, + } + } + + pub fn wait(&mut self) -> TurnLockGuard { + unsafe { self.raw.lock.wait(self.index) }; + // Safety: the turn lock is now held + unsafe { self.guard() } + } + + unsafe fn next(&self) { + self.raw.lock.next(self.index); + } +} + +#[must_use = "if unused the TurnLock will immediately unlock"] +pub struct TurnLockGuard<'a, T> { + handle: &'a TurnHandle, + marker: PhantomData<&'a T>, +} + +impl<'a, T> TurnLockGuard<'a, T> { + pub fn next(self) { + drop(self) + } + + pub fn handle(&self) -> &TurnHandle { + self.handle + } +} +impl<'a, T> Drop for TurnLockGuard<'a, T> { + fn drop(&mut self) { + unsafe { self.handle.next() }; + } +} + +impl<'a, T> Deref for TurnLockGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.handle.raw.data.get() } + } +} + +impl<'a, T> DerefMut for TurnLockGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.handle.raw.data.get() } } } #[cfg(test)] mod tests { - use crate::TurnLock; + use crate::TurnHandle; #[test] fn it_works() { @@ -55,15 +154,15 @@ mod tests { #[test] fn three_turns() { - let mut v = TurnLock::new(3); - v[0].wait(); - v[0].next(); - v[1].wait(); - v[1].next(); - v[2].wait(); - v[2].next(); - v[0].wait(); - v[0].next(); - assert_eq!(v[2].current(), 1); + let mut v = TurnHandle::<()>::new(3, ()); + let t0 = v[0].wait(); + t0.next(); + let t1 = v[1].wait(); + t1.next(); + let t2 = v[2].wait(); + t2.next(); + let t0 = v[0].wait(); + t0.next(); + //assert_eq!(v[2].current(), 1); } }