aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJake McGinty <me@jake.su>2018-02-06 00:49:00 +0000
committerJake McGinty <me@jake.su>2018-02-06 00:49:00 +0000
commit438c6244647d2e2d1cf3aac7851f767cd2a52637 (patch)
tree62f721581781fb15011ee2103b4f43dbef90d56c /src
parentfinish up basic utun ipv6 support (diff)
downloadwireguard-rs-438c6244647d2e2d1cf3aac7851f767cd2a52637.tar.xz
wireguard-rs-438c6244647d2e2d1cf3aac7851f767cd2a52637.zip
add sopium's AntiReplay struct and implement it for transport packets
Diffstat (limited to 'src')
-rw-r--r--src/anti_replay.rs184
-rw-r--r--src/interface/peer_server.rs20
-rw-r--r--src/main.rs2
-rw-r--r--src/protocol/peer.rs23
4 files changed, 207 insertions, 22 deletions
diff --git a/src/anti_replay.rs b/src/anti_replay.rs
new file mode 100644
index 0000000..e3803aa
--- /dev/null
+++ b/src/anti_replay.rs
@@ -0,0 +1,184 @@
+// Copyright 2017 Guanhao Yin <sopium@mysterious.site>
+
+// This file is part of WireGuard.rs.
+
+// WireGuard.rs is free software: you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+
+// WireGuard.rs is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+
+// You should have received a copy of the GNU General Public License
+// along with WireGuard.rs. If not, see <https://www.gnu.org/licenses/>.
+
+
+// This is RFC 6479.
+
+// Power of 2.
+const BITMAP_BITLEN: u64 = 2048;
+
+const SIZE_OF_INTEGER: u64 = 32;
+const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_INTEGER) as usize;
+const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1;
+// REDUNDANT_BIT_SHIFTS = log2(SIZE_OF_INTEGER).
+const REDUNDANT_BIT_SHIFTS: u64 = 5;
+const BITMAP_LOC_MASK: u64 = SIZE_OF_INTEGER - 1;
+/// Size of anti-replay window.
+pub const WINDOW_SIZE: u64 = BITMAP_BITLEN - SIZE_OF_INTEGER;
+
+pub struct AntiReplay {
+ bitmap: [u32; BITMAP_LEN],
+ last: u64,
+}
+
+impl Default for AntiReplay {
+ fn default() -> Self {
+ AntiReplay::new()
+ }
+}
+
+impl AntiReplay {
+ pub fn new() -> Self {
+ AntiReplay {
+ last: 0,
+ bitmap: [0; BITMAP_LEN],
+ }
+ }
+
+ /// Returns true if check is passed, i.e., not a replay or too old.
+ ///
+ /// Unlike RFC 6479, zero is allowed.
+ pub fn check(&self, seq: u64) -> bool {
+ // Larger is always good.
+ if seq > self.last {
+ return true;
+ }
+
+ if self.last - seq > WINDOW_SIZE {
+ return false;
+ }
+
+ let bit_location = seq & BITMAP_LOC_MASK;
+ let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK;
+
+ self.bitmap[index as usize] & (1 << bit_location) == 0
+ }
+
+ /// Should only be called if check returns true.
+ pub fn update(&mut self, seq: u64) {
+ debug_assert!(self.check(seq));
+
+ let index = seq >> REDUNDANT_BIT_SHIFTS;
+
+ if seq > self.last {
+ let index_cur = self.last >> REDUNDANT_BIT_SHIFTS;
+ let diff = index - index_cur;
+
+ if diff >= BITMAP_LEN as u64 {
+ self.bitmap = [0; BITMAP_LEN];
+ } else {
+ for i in 0..diff {
+ let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK;
+ self.bitmap[real_index as usize] = 0;
+ }
+ }
+
+ self.last = seq;
+ }
+
+ let index = index & BITMAP_INDEX_MASK;
+ let bit_location = seq & BITMAP_LOC_MASK;
+ self.bitmap[index as usize] |= 1 << bit_location;
+ }
+
+ pub fn check_and_update(&mut self, seq: u64) -> bool {
+ if self.check(seq) {
+ self.update(seq);
+ true
+ } else {
+ false
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn anti_replay() {
+ let mut ar = AntiReplay::new();
+
+ for i in 0..20000 {
+ assert!(ar.check_and_update(i));
+ }
+
+ for i in (0..20000).rev() {
+ assert!(!ar.check(i));
+ }
+
+ assert!(ar.check_and_update(65536));
+ for i in (65536 - WINDOW_SIZE)..65535 {
+ assert!(ar.check_and_update(i));
+ }
+ for i in (65536 - 10 * WINDOW_SIZE)..65535 {
+ assert!(!ar.check(i));
+ }
+
+ ar.check_and_update(66000);
+ for i in 65537..66000 {
+ assert!(ar.check_and_update(i));
+ }
+ for i in 65537..66000 {
+ assert!(!ar.check_and_update(i));
+ }
+
+ // Test max u64.
+ let next = u64::max_value();
+ assert!(ar.check_and_update(next));
+ assert!(!ar.check(next));
+ for i in (next - WINDOW_SIZE)..next {
+ assert!(ar.check_and_update(i));
+ }
+ for i in (next - 20 * WINDOW_SIZE)..next {
+ assert!(!ar.check(i));
+ }
+ }
+
+ #[bench]
+ fn bench_anti_replay_sequential(b: &mut ::test::Bencher) {
+ let mut ar = AntiReplay::new();
+ let mut seq = 0;
+
+ b.iter(|| {
+ assert!(ar.check_and_update(seq));
+ seq += 1;
+ });
+ }
+
+ #[bench]
+ fn bench_anti_replay_old(b: &mut ::test::Bencher) {
+ let mut ar = AntiReplay::new();
+ ar.check_and_update(12345);
+ ar.check_and_update(11234);
+
+ b.iter(|| {
+ assert!(!ar.check_and_update(11234));
+ });
+ }
+
+ #[bench]
+ fn bench_anti_replay_large_skip(b: &mut ::test::Bencher) {
+ let mut ar = AntiReplay::new();
+ let mut seq = 0;
+
+ b.iter(|| {
+ assert!(ar.check_and_update(seq));
+ seq += 30000;
+ });
+ }
+}
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs
index 6c904fe..0ca7f3d 100644
--- a/src/interface/peer_server.rs
+++ b/src/interface/peer_server.rs
@@ -210,29 +210,13 @@ impl PeerServer {
let our_index_received = LittleEndian::read_u32(&packet[4..]);
let nonce = LittleEndian::read_u64(&packet[8..]);
- let mut raw_packet = vec![0u8; 1500];
let lookup = state.index_map.get(&our_index_received);
if let Some(ref peer) = lookup {
let mut peer = peer.borrow_mut();
- peer.rx_bytes += packet.len() as u64;
-
- // TODO: map index not just to peer, but to specific session instead of guessing
- let res = {
- let noise = peer.current_noise().expect("current noise session");
- noise.set_receiving_nonce(nonce).unwrap();
- noise.read_message(&packet[16..], &mut raw_packet).map_err(|_| ())
- }.or_else(|_| {
- if let Some(noise) = peer.past_noise() {
- noise.set_receiving_nonce(nonce).unwrap();
- noise.read_message(&packet[16..], &mut raw_packet).map_err(|_| ())
- } else {
- Err(())
- }
- });
+ let res = peer.decrypt_transport_packet(our_index_received, nonce, &packet[16..]);
- if let Ok(payload_len) = res {
- raw_packet.truncate(payload_len);
+ if let Ok(raw_packet) = res {
trace_packet("received TRANSPORT: ", &raw_packet);
let utun_packet = match (raw_packet[0] & 0xf0) >> 4 {
4 => UtunPacket::Inet4(raw_packet),
diff --git a/src/main.rs b/src/main.rs
index 88371b4..2ebd052 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
#![feature(ip_constructors)]
+#![feature(option_filter)]
#![allow(unused_imports)]
#[macro_use] extern crate error_chain;
@@ -32,6 +33,7 @@ mod error;
mod interface;
mod protocol;
mod types;
+mod anti_replay;
use std::path::PathBuf;
diff --git a/src/protocol/peer.rs b/src/protocol/peer.rs
index ed1e6de..d4b6b40 100644
--- a/src/protocol/peer.rs
+++ b/src/protocol/peer.rs
@@ -1,3 +1,4 @@
+use anti_replay::AntiReplay;
use byteorder::{ByteOrder, BigEndian, LittleEndian};
use blake2_rfc::blake2s::{Blake2s, blake2s};
use snow::{self, NoiseBuilder};
@@ -34,6 +35,7 @@ pub struct Session {
pub noise: snow::Session,
pub our_index: u32,
pub their_index: u32,
+ pub anti_replay: AntiReplay,
}
impl Session {
@@ -43,6 +45,7 @@ impl Session {
noise: session,
our_index: rand::thread_rng().gen::<u32>(),
their_index,
+ anti_replay: AntiReplay::default(),
}
}
@@ -51,6 +54,7 @@ impl Session {
noise: self.noise.into_transport_mode().unwrap(),
our_index: self.our_index,
their_index: self.their_index,
+ anti_replay: self.anti_replay,
}
}
}
@@ -61,6 +65,7 @@ impl From<snow::Session> for Session {
noise: session,
our_index: rand::thread_rng().gen::<u32>(),
their_index: 0,
+ anti_replay: AntiReplay::default(),
}
}
}
@@ -110,11 +115,21 @@ impl Peer {
Ok(())
}
- pub fn past_noise(&mut self) -> Option<&mut snow::Session> {
- if let Some(ref mut session) = self.sessions.past {
- Some(&mut session.noise)
+ pub fn decrypt_transport_packet(&mut self, our_index: u32, nonce: u64, packet: &[u8]) -> Result<Vec<u8>, ()> {
+ let mut raw_packet = vec![0u8; 1500];
+ self.rx_bytes += packet.len() as u64;
+
+ let session = self.sessions.current.as_mut().filter(|session| session.our_index == our_index)
+ .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index))
+ .ok_or_else(|| ())?;
+
+ if session.anti_replay.check_and_update(nonce) {
+ session.noise.set_receiving_nonce(nonce).unwrap();
+ let len = session.noise.read_message(packet, &mut raw_packet).map_err(|_| ())?;
+ raw_packet.truncate(len);
+ Ok(raw_packet)
} else {
- None
+ Err(())
}
}