diff options
author | Jake McGinty <me@jake.su> | 2018-02-06 00:49:00 +0000 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2018-02-06 00:49:00 +0000 |
commit | 438c6244647d2e2d1cf3aac7851f767cd2a52637 (patch) | |
tree | 62f721581781fb15011ee2103b4f43dbef90d56c /src | |
parent | finish up basic utun ipv6 support (diff) | |
download | wireguard-rs-438c6244647d2e2d1cf3aac7851f767cd2a52637.tar.xz wireguard-rs-438c6244647d2e2d1cf3aac7851f767cd2a52637.zip |
add sopium's AntiReplay struct and implement it for transport packets
Diffstat (limited to 'src')
-rw-r--r-- | src/anti_replay.rs | 184 | ||||
-rw-r--r-- | src/interface/peer_server.rs | 20 | ||||
-rw-r--r-- | src/main.rs | 2 | ||||
-rw-r--r-- | src/protocol/peer.rs | 23 |
4 files changed, 207 insertions, 22 deletions
diff --git a/src/anti_replay.rs b/src/anti_replay.rs new file mode 100644 index 0000000..e3803aa --- /dev/null +++ b/src/anti_replay.rs @@ -0,0 +1,184 @@ +// Copyright 2017 Guanhao Yin <sopium@mysterious.site> + +// This file is part of WireGuard.rs. + +// WireGuard.rs is free software: you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. + +// WireGuard.rs is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with WireGuard.rs. If not, see <https://www.gnu.org/licenses/>. + + +// This is RFC 6479. + +// Power of 2. +const BITMAP_BITLEN: u64 = 2048; + +const SIZE_OF_INTEGER: u64 = 32; +const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_INTEGER) as usize; +const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1; +// REDUNDANT_BIT_SHIFTS = log2(SIZE_OF_INTEGER). +const REDUNDANT_BIT_SHIFTS: u64 = 5; +const BITMAP_LOC_MASK: u64 = SIZE_OF_INTEGER - 1; +/// Size of anti-replay window. +pub const WINDOW_SIZE: u64 = BITMAP_BITLEN - SIZE_OF_INTEGER; + +pub struct AntiReplay { + bitmap: [u32; BITMAP_LEN], + last: u64, +} + +impl Default for AntiReplay { + fn default() -> Self { + AntiReplay::new() + } +} + +impl AntiReplay { + pub fn new() -> Self { + AntiReplay { + last: 0, + bitmap: [0; BITMAP_LEN], + } + } + + /// Returns true if check is passed, i.e., not a replay or too old. + /// + /// Unlike RFC 6479, zero is allowed. + pub fn check(&self, seq: u64) -> bool { + // Larger is always good. + if seq > self.last { + return true; + } + + if self.last - seq > WINDOW_SIZE { + return false; + } + + let bit_location = seq & BITMAP_LOC_MASK; + let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK; + + self.bitmap[index as usize] & (1 << bit_location) == 0 + } + + /// Should only be called if check returns true. + pub fn update(&mut self, seq: u64) { + debug_assert!(self.check(seq)); + + let index = seq >> REDUNDANT_BIT_SHIFTS; + + if seq > self.last { + let index_cur = self.last >> REDUNDANT_BIT_SHIFTS; + let diff = index - index_cur; + + if diff >= BITMAP_LEN as u64 { + self.bitmap = [0; BITMAP_LEN]; + } else { + for i in 0..diff { + let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK; + self.bitmap[real_index as usize] = 0; + } + } + + self.last = seq; + } + + let index = index & BITMAP_INDEX_MASK; + let bit_location = seq & BITMAP_LOC_MASK; + self.bitmap[index as usize] |= 1 << bit_location; + } + + pub fn check_and_update(&mut self, seq: u64) -> bool { + if self.check(seq) { + self.update(seq); + true + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn anti_replay() { + let mut ar = AntiReplay::new(); + + for i in 0..20000 { + assert!(ar.check_and_update(i)); + } + + for i in (0..20000).rev() { + assert!(!ar.check(i)); + } + + assert!(ar.check_and_update(65536)); + for i in (65536 - WINDOW_SIZE)..65535 { + assert!(ar.check_and_update(i)); + } + for i in (65536 - 10 * WINDOW_SIZE)..65535 { + assert!(!ar.check(i)); + } + + ar.check_and_update(66000); + for i in 65537..66000 { + assert!(ar.check_and_update(i)); + } + for i in 65537..66000 { + assert!(!ar.check_and_update(i)); + } + + // Test max u64. + let next = u64::max_value(); + assert!(ar.check_and_update(next)); + assert!(!ar.check(next)); + for i in (next - WINDOW_SIZE)..next { + assert!(ar.check_and_update(i)); + } + for i in (next - 20 * WINDOW_SIZE)..next { + assert!(!ar.check(i)); + } + } + + #[bench] + fn bench_anti_replay_sequential(b: &mut ::test::Bencher) { + let mut ar = AntiReplay::new(); + let mut seq = 0; + + b.iter(|| { + assert!(ar.check_and_update(seq)); + seq += 1; + }); + } + + #[bench] + fn bench_anti_replay_old(b: &mut ::test::Bencher) { + let mut ar = AntiReplay::new(); + ar.check_and_update(12345); + ar.check_and_update(11234); + + b.iter(|| { + assert!(!ar.check_and_update(11234)); + }); + } + + #[bench] + fn bench_anti_replay_large_skip(b: &mut ::test::Bencher) { + let mut ar = AntiReplay::new(); + let mut seq = 0; + + b.iter(|| { + assert!(ar.check_and_update(seq)); + seq += 30000; + }); + } +} diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index 6c904fe..0ca7f3d 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -210,29 +210,13 @@ impl PeerServer { let our_index_received = LittleEndian::read_u32(&packet[4..]); let nonce = LittleEndian::read_u64(&packet[8..]); - let mut raw_packet = vec![0u8; 1500]; let lookup = state.index_map.get(&our_index_received); if let Some(ref peer) = lookup { let mut peer = peer.borrow_mut(); - peer.rx_bytes += packet.len() as u64; - - // TODO: map index not just to peer, but to specific session instead of guessing - let res = { - let noise = peer.current_noise().expect("current noise session"); - noise.set_receiving_nonce(nonce).unwrap(); - noise.read_message(&packet[16..], &mut raw_packet).map_err(|_| ()) - }.or_else(|_| { - if let Some(noise) = peer.past_noise() { - noise.set_receiving_nonce(nonce).unwrap(); - noise.read_message(&packet[16..], &mut raw_packet).map_err(|_| ()) - } else { - Err(()) - } - }); + let res = peer.decrypt_transport_packet(our_index_received, nonce, &packet[16..]); - if let Ok(payload_len) = res { - raw_packet.truncate(payload_len); + if let Ok(raw_packet) = res { trace_packet("received TRANSPORT: ", &raw_packet); let utun_packet = match (raw_packet[0] & 0xf0) >> 4 { 4 => UtunPacket::Inet4(raw_packet), diff --git a/src/main.rs b/src/main.rs index 88371b4..2ebd052 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ #![feature(ip_constructors)] +#![feature(option_filter)] #![allow(unused_imports)] #[macro_use] extern crate error_chain; @@ -32,6 +33,7 @@ mod error; mod interface; mod protocol; mod types; +mod anti_replay; use std::path::PathBuf; diff --git a/src/protocol/peer.rs b/src/protocol/peer.rs index ed1e6de..d4b6b40 100644 --- a/src/protocol/peer.rs +++ b/src/protocol/peer.rs @@ -1,3 +1,4 @@ +use anti_replay::AntiReplay; use byteorder::{ByteOrder, BigEndian, LittleEndian}; use blake2_rfc::blake2s::{Blake2s, blake2s}; use snow::{self, NoiseBuilder}; @@ -34,6 +35,7 @@ pub struct Session { pub noise: snow::Session, pub our_index: u32, pub their_index: u32, + pub anti_replay: AntiReplay, } impl Session { @@ -43,6 +45,7 @@ impl Session { noise: session, our_index: rand::thread_rng().gen::<u32>(), their_index, + anti_replay: AntiReplay::default(), } } @@ -51,6 +54,7 @@ impl Session { noise: self.noise.into_transport_mode().unwrap(), our_index: self.our_index, their_index: self.their_index, + anti_replay: self.anti_replay, } } } @@ -61,6 +65,7 @@ impl From<snow::Session> for Session { noise: session, our_index: rand::thread_rng().gen::<u32>(), their_index: 0, + anti_replay: AntiReplay::default(), } } } @@ -110,11 +115,21 @@ impl Peer { Ok(()) } - pub fn past_noise(&mut self) -> Option<&mut snow::Session> { - if let Some(ref mut session) = self.sessions.past { - Some(&mut session.noise) + pub fn decrypt_transport_packet(&mut self, our_index: u32, nonce: u64, packet: &[u8]) -> Result<Vec<u8>, ()> { + let mut raw_packet = vec![0u8; 1500]; + self.rx_bytes += packet.len() as u64; + + let session = self.sessions.current.as_mut().filter(|session| session.our_index == our_index) + .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index)) + .ok_or_else(|| ())?; + + if session.anti_replay.check_and_update(nonce) { + session.noise.set_receiving_nonce(nonce).unwrap(); + let len = session.noise.read_message(packet, &mut raw_packet).map_err(|_| ())?; + raw_packet.truncate(len); + Ok(raw_packet) } else { - None + Err(()) } } |