aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJake McGinty <me@jake.su>2017-12-30 12:34:09 -0800
committerJake McGinty <me@jake.su>2017-12-30 12:34:09 -0800
commita2c20c5b3abcb34beae91dac5c63aece06d0c903 (patch)
tree58c7195be4489bfee9b39fb87ba2e6a975e53eba /src
parentadd license and readme (diff)
downloadwireguard-rs-a2c20c5b3abcb34beae91dac5c63aece06d0c903.tar.xz
wireguard-rs-a2c20c5b3abcb34beae91dac5c63aece06d0c903.zip
PeerServer custom future refactor
Diffstat (limited to 'src')
-rw-r--r--src/interface/mod.rs179
-rw-r--r--src/interface/peer_server.rs174
-rw-r--r--src/main.rs11
3 files changed, 217 insertions, 147 deletions
diff --git a/src/interface/mod.rs b/src/interface/mod.rs
index 8cbbfb3..943c396 100644
--- a/src/interface/mod.rs
+++ b/src/interface/mod.rs
@@ -1,6 +1,9 @@
mod config;
+mod peer_server;
use self::config::{ConfigurationServiceManager, UpdateEvent, Command, ConfigurationCodec};
+use self::peer_server::{PeerServer, PeerServerMessage};
+
use base64;
use hex;
use byteorder::{ByteOrder, BigEndian, LittleEndian};
@@ -25,32 +28,25 @@ use tokio_io::codec::{Framed, Encoder, Decoder};
use tokio_uds::{UnixListener};
use tokio_timer::{Interval, Timer};
-fn debug_packet(header: &str, packet: &[u8]) {
+
+pub fn debug_packet(header: &str, packet: &[u8]) {
let packet = Ipv4Packet::new(packet);
debug!("{} {:?}", header, packet);
}
-pub struct Interface {
- name: String,
- info: Rc<RefCell<InterfaceInfo>>,
- peers: Rc<RefCell<HashMap<[u8; 32], Rc<RefCell<Peer>>>>>,
- ids: Rc<RefCell<HashMap<u32, Rc<RefCell<Peer>>>>>,
-}
-
-struct VecUdpCodec;
-impl UdpCodec for VecUdpCodec {
- type In = (SocketAddr, Vec<u8>);
- type Out = (SocketAddr, Vec<u8>);
+pub type SharedPeer = Rc<RefCell<Peer>>;
+pub type SharedState = Rc<RefCell<State>>;
- fn decode(&mut self, src: &SocketAddr, buf: &[u8]) -> io::Result<Self::In> {
- Ok((*src, buf.to_vec()))
- }
+#[derive(Default)]
+pub struct State {
+ pubkey_map: HashMap<[u8; 32], SharedPeer>,
+ index_map: HashMap<u32, SharedPeer>,
+ interface_info: InterfaceInfo,
+}
- fn encode(&mut self, msg: Self::Out, buf: &mut Vec<u8>) -> SocketAddr {
- let (addr, mut data) = msg;
- buf.append(&mut data);
- addr
- }
+pub struct Interface {
+ name: String,
+ state: SharedState,
}
struct VecUtunCodec;
@@ -76,15 +72,14 @@ impl UtunCodec for VecUtunCodec {
impl Interface {
pub fn new(name: &str) -> Self {
- let info = Rc::new(RefCell::new(InterfaceInfo::default()));
- let peers = Rc::new(RefCell::new(HashMap::new()));
- let ids = Rc::new(RefCell::new(HashMap::new()));
- let _config_service = ConfigurationServiceManager::new(name);
+ let state = State {
+ pubkey_map: HashMap::new(),
+ index_map: HashMap::new(),
+ interface_info: InterfaceInfo::default(),
+ };
Interface {
name: name.to_owned(),
- info,
- peers,
- ids,
+ state: Rc::new(RefCell::new(state)),
}
}
@@ -92,109 +87,20 @@ impl Interface {
let mut core = Core::new().unwrap();
let (utun_tx, utun_rx) = unsync::mpsc::channel::<Vec<u8>>(1024);
- let udp_socket = UdpSocket::bind(&([0,0,0,0], 0).into(), &core.handle()).unwrap();
- let (tx, rx) = unsync::mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
- let (udp_writer, udp_reader) = udp_socket.framed(VecUdpCodec{}).split();
- let udp_read_fut = udp_reader.for_each({
- let ids_ref = self.ids.clone();
- let handle = core.handle();
- let tx = tx.clone();
- let interface_info = self.info.clone();
- move |(_socket_addr, packet)| {
- debug!("got a UDP packet of length {}, packet type {}", packet.len(), packet[0]);
- match packet[0] {
- 1 => {
- info!("got handshake initialization.");
- },
- 2 => {
- let their_index = LittleEndian::read_u32(&packet[4..]);
- let our_index = LittleEndian::read_u32(&packet[8..]);
- let mut ids = ids_ref.borrow_mut();
- let peer_ref = ids.get(&our_index).unwrap().clone();
- let mut peer = peer_ref.borrow_mut();
- peer.sessions.next.as_mut().unwrap().their_index = their_index;
- let payload_len = peer.next_noise().expect("pending noise session")
- .read_message(&packet[12..60], &mut []).unwrap();
- assert!(payload_len == 0);
- peer.ratchet_session().unwrap();
- info!("got handshake response, ratcheted session.");
- let tx = tx.clone();
-
- let interface_info = interface_info.borrow();
- let noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap())
- .local_private_key(&interface_info.private_key.expect("no private key!"))
- .remote_public_key(&peer.info.pub_key)
- .prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes())
- .psk(2, &peer.info.psk.expect("no psk!"))
- .build_initiator().unwrap();
- peer.set_next_session(noise.into());
- let _ = ids.insert(peer.our_next_index().unwrap(), peer_ref.clone());
-
- let init_packet = peer.get_handshake_packet();
- let endpoint = peer.info.endpoint.unwrap().clone();
-
- let timer = Timer::default();
- let sleep = timer.sleep(Duration::from_secs(120));
- let boop = sleep.and_then({
- let handle = handle.clone();
- let peer_ref = peer_ref.clone();
- let interface_info = interface_info.clone();
- move |_| {
- info!("sending rekey!");
- handle.spawn(tx.clone().send((endpoint, init_packet))
- .map(|_| ())
- .map_err(|_| ()));
- Ok(())
- }
- }).map_err(|_|());
- handle.spawn(boop);
- },
- 4 => {
- let our_index_received = LittleEndian::read_u32(&packet[4..]);
- let nonce = LittleEndian::read_u64(&packet[8..]);
-
- let mut raw_packet = [0u8; 1500];
- let ids = ids_ref.borrow();
- let lookup = ids.get(&our_index_received);
- if let Some(ref peer) = lookup {
- let mut peer = peer.borrow_mut();
- // info!("retrieved peer with pubkey {}", base64::encode(&peer.pubkey));
- // info!("ok going to try to decrypt");
-
- peer.rx_bytes += packet.len();
- let noise = peer.current_noise().expect("current noise session");
- noise.set_receiving_nonce(nonce).unwrap();
- let payload_len = noise.read_message(&packet[16..], &mut raw_packet).unwrap();
- debug_packet("received TRANSPORT: ", &raw_packet[..payload_len]);
- handle.spawn(utun_tx.clone().send(raw_packet[..payload_len].to_owned())
- .map(|_| ())
- .map_err(|_| ()));
- }
- },
- _ => unimplemented!()
- }
- Ok(())
- }
- }).map_err(|_| ());
-
- let udp_write_fut = udp_writer.sink_map_err(|_| ()).send_all(
- rx.map(|(addr, packet)| {
- debug!("sending encrypted UDP packet");
- (addr, packet)
- }).map_err(|_| ())).map_err(|_| ());
+ let peer_server = PeerServer::bind(core.handle(), self.state.clone(), utun_tx.clone());
let utun_stream = UtunStream::connect(&self.name, &core.handle()).unwrap().framed(VecUtunCodec{});
let (utun_writer, utun_reader) = utun_stream.split();
let utun_fut = utun_reader.for_each({
- let ids = self.ids.clone();
+ let state = self.state.clone();
let utun_handle = core.handle();
- let udp_tx = tx.clone();
+ let udp_tx = peer_server.udp_tx();
move |packet| {
debug_packet("received UTUN packet: ", &packet);
+ let state = state.borrow();
let mut ping_packet = [0u8; 1500];
- let ids = ids.borrow();
- let (_key, peer) = ids.iter().next().unwrap(); // TODO destination IP peer lookup
+ let (_key, peer) = state.pubkey_map.iter().next().unwrap(); // TODO destination IP peer lookup
let mut peer = peer.borrow_mut();
ping_packet[0] = 4;
let their_index = peer.their_current_index().expect("no current index for them");
@@ -223,8 +129,7 @@ impl Interface {
let h = handle.clone();
let config_server = listener.incoming().for_each({
let config_tx = config_tx.clone();
- let info = self.info.clone();
- let peers = self.peers.clone();
+ let state = self.state.clone();
move |(stream, _)| {
let (sink, stream) = stream.framed(ConfigurationCodec {}).split();
debug!("UnixServer connection.");
@@ -232,17 +137,17 @@ impl Interface {
let handle = h.clone();
let responses = stream.and_then({
let config_tx = config_tx.clone();
- let info = info.clone();
- let peers = peers.clone();
+ let state = state.clone();
move |command| {
+ let state = state.borrow();
match command {
Command::Set(_version, items) => {
config_tx.clone().send_all(stream::iter_ok(items)).wait().unwrap();
future::ok("errno=0\nerrno=0\n\n".to_string())
},
Command::Get(_version) => {
- let info = info.borrow();
- let peers = peers.borrow();
+ let info = &state.interface_info;
+ let peers = &state.pubkey_map;
let mut s = String::new();
if let Some(private_key) = info.private_key {
s.push_str(&format!("private_key={}\n", hex::encode(private_key)));
@@ -266,26 +171,24 @@ impl Interface {
}).map_err(|_| ());
let config_fut = config_rx.for_each({
- let tx = tx.clone();
+ let tx = peer_server.udp_tx().clone();
let handle = handle.clone();
+ let state = self.state.clone();
move |event| {
- let interface_info = self.info.clone();
+ let mut state = state.borrow_mut();
match event {
UpdateEvent::PrivateKey(private_key) => {
- let mut interface_info = interface_info.borrow_mut();
- interface_info.private_key = Some(private_key);
+ state.interface_info.private_key = Some(private_key);
debug!("set new private key");
},
UpdateEvent::ListenPort(port) => {
- let mut interface_info = interface_info.borrow_mut();
- interface_info.listen_port = Some(port);
+ state.interface_info.listen_port = Some(port);
debug!("set new listen port");
},
UpdateEvent::UpdatePeer(info) => {
info!("added new peer: {}", info);
- let interface_info = interface_info.borrow();
let mut noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap())
- .local_private_key(&interface_info.private_key.expect("no private key!"))
+ .local_private_key(&state.interface_info.private_key.expect("no private key!"))
.remote_public_key(&info.pub_key)
.prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes())
.psk(2, &info.psk.expect("no psk!"))
@@ -298,8 +201,8 @@ impl Interface {
let our_index = peer.our_next_index().unwrap();
let peer = Rc::new(RefCell::new(peer));
- let _ = self.ids.borrow_mut().insert(our_index, peer.clone());
- let _ = self.peers.borrow_mut().insert(info.pub_key, peer);
+ let _ = state.index_map.insert(our_index, peer.clone());
+ let _ = state.pubkey_map.insert(info.pub_key, peer);
handle.spawn(tx.clone().send((info.endpoint.unwrap(), init_packet))
.map(|_| ())
@@ -312,6 +215,6 @@ impl Interface {
}
}).map_err(|_| ());
- core.run(utun_fut.join(utun_write_fut.join(udp_read_fut.join(udp_write_fut.join(config_fut.join(config_server)))))).unwrap();
+ core.run(peer_server.join(utun_fut.join(utun_write_fut.join(config_fut.join(config_server))))).unwrap();
}
}
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs
new file mode 100644
index 0000000..b812677
--- /dev/null
+++ b/src/interface/peer_server.rs
@@ -0,0 +1,174 @@
+use super::{SharedState, SharedPeer, debug_packet};
+
+use std::io;
+use std::net::SocketAddr;
+use std::time::Duration;
+
+use byteorder::{ByteOrder, BigEndian, LittleEndian};
+use futures::{Async, Future, Stream, Sink, Poll, future, unsync, sync, stream};
+use tokio_core::net::{UdpSocket, UdpCodec, UdpFramed};
+use tokio_core::reactor::Handle;
+use tokio_io::codec::Framed;
+use tokio_timer::{Interval, Timer};
+use snow::NoiseBuilder;
+
+
+pub type PeerServerMessage = (SocketAddr, Vec<u8>);
+struct VecUdpCodec;
+impl UdpCodec for VecUdpCodec {
+ type In = PeerServerMessage;
+ type Out = PeerServerMessage;
+
+ fn decode(&mut self, src: &SocketAddr, buf: &[u8]) -> io::Result<Self::In> {
+ Ok((*src, buf.to_vec()))
+ }
+
+ fn encode(&mut self, msg: Self::Out, buf: &mut Vec<u8>) -> SocketAddr {
+ let (addr, mut data) = msg;
+ buf.append(&mut data);
+ addr
+ }
+}
+
+
+pub struct PeerServer {
+ handle: Handle,
+ shared_state: SharedState,
+ udp_stream: stream::SplitStream<UdpFramed<VecUdpCodec>>,
+ rx: unsync::mpsc::Receiver<Vec<u8>>,
+ udp_tx: unsync::mpsc::Sender<(SocketAddr, Vec<u8>)>,
+ tunnel_tx: unsync::mpsc::Sender<Vec<u8>>,
+ pub tx: unsync::mpsc::Sender<Vec<u8>>,
+}
+
+impl PeerServer {
+ pub fn bind(handle: Handle, shared_state: SharedState, tunnel_tx: unsync::mpsc::Sender<Vec<u8>>) -> Self {
+ let socket = UdpSocket::bind(&([0,0,0,0], 0).into(), &handle.clone()).unwrap();
+ let (udp_sink, udp_stream) = socket.framed(VecUdpCodec{}).split();
+ let (udp_tx, udp_rx) = unsync::mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
+ let (tx, rx) = unsync::mpsc::channel::<Vec<u8>>(1024);
+
+ let udp_write_passthrough = udp_sink.sink_map_err(|_| ()).send_all(
+ udp_rx.map(|(addr, packet)| {
+ debug_packet("sending UDP: ", &packet);
+ (addr, packet)
+ }).map_err(|_| ()))
+ .then(|_| Ok(()));
+ handle.spawn(udp_write_passthrough);
+
+ PeerServer {
+ handle, shared_state, udp_stream, udp_tx, tunnel_tx, tx, rx
+ }
+ }
+
+ pub fn tx(&self) -> unsync::mpsc::Sender<Vec<u8>> {
+ self.tx.clone()
+ }
+
+ pub fn udp_tx(&self) -> unsync::mpsc::Sender<(SocketAddr, Vec<u8>)> {
+ self.udp_tx.clone()
+ }
+
+ fn handle_incoming_packet(&mut self, addr: SocketAddr, packet: Vec<u8>) {
+ debug!("got a UDP packet of length {}, packet type {}", packet.len(), packet[0]);
+ let mut state = self.shared_state.borrow_mut();
+ match packet[0] {
+ 1 => {
+ info!("got handshake initialization.");
+ },
+ 2 => {
+ let their_index = LittleEndian::read_u32(&packet[4..]);
+ let our_index = LittleEndian::read_u32(&packet[8..]);
+ let peer_ref = state.index_map.get(&our_index).unwrap().clone();
+ let mut peer = peer_ref.borrow_mut();
+ peer.sessions.next.as_mut().unwrap().their_index = their_index;
+ let payload_len = peer.next_noise().expect("pending noise session")
+ .read_message(&packet[12..60], &mut []).unwrap();
+ assert!(payload_len == 0);
+ peer.ratchet_session().unwrap();
+ info!("got handshake response, ratcheted session.");
+
+ let noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap())
+ .local_private_key(&state.interface_info.private_key.expect("no private key!"))
+ .remote_public_key(&peer.info.pub_key)
+ .prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes())
+ .psk(2, &peer.info.psk.expect("no psk!"))
+ .build_initiator().unwrap();
+ peer.set_next_session(noise.into());
+
+ let _ = state.index_map.insert(peer.our_next_index().unwrap(), peer_ref.clone());
+
+ let init_packet = peer.get_handshake_packet();
+ let endpoint = peer.info.endpoint.unwrap().clone();
+
+ let timer = Timer::default();
+ let sleep = timer.sleep(Duration::from_secs(120));
+ let boop = sleep.and_then({
+ let handle = self.handle.clone();
+ let tx = self.udp_tx.clone();
+ let peer_ref = peer_ref.clone();
+ move |_| {
+ info!("sending rekey!");
+ handle.spawn(tx.clone().send((endpoint, init_packet))
+ .map(|_| ())
+ .map_err(|_| ()));
+ Ok(())
+ }
+ }).map_err(|_|());
+ self.handle.spawn(boop);
+ },
+ 4 => {
+ let our_index_received = LittleEndian::read_u32(&packet[4..]);
+ let nonce = LittleEndian::read_u64(&packet[8..]);
+
+ let mut raw_packet = [0u8; 1500];
+ let lookup = state.index_map.get(&our_index_received);
+ if let Some(ref peer) = lookup {
+ let mut peer = peer.borrow_mut();
+
+ peer.rx_bytes += packet.len();
+ let noise = peer.current_noise().expect("current noise session");
+ noise.set_receiving_nonce(nonce).unwrap();
+ let payload_len = noise.read_message(&packet[16..], &mut raw_packet).unwrap();
+ debug_packet("received TRANSPORT: ", &raw_packet[..payload_len]);
+ self.handle.spawn(self.tunnel_tx.clone().send(raw_packet[..payload_len].to_owned())
+ .map(|_| ())
+ .map_err(|_| ()));
+ }
+ },
+ _ => unimplemented!()
+ }
+ }
+
+ fn handle_outgoing_packet(&mut self, packet: Vec<u8>) {
+ debug!("handle_outgoing_packet()");
+
+ }
+}
+
+impl Future for PeerServer {
+ type Item = ();
+ type Error = ();
+
+ fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+ // Handle UDP packets from the outside world
+ loop {
+ match self.udp_stream.poll() {
+ Ok(Async::Ready(Some((addr, packet)))) => self.handle_incoming_packet(addr, packet),
+ Ok(Async::NotReady) => break,
+ Ok(Async::Ready(None)) | Err(_) => return Err(()),
+ }
+ }
+
+ // Handle packets coming from the local tunnel
+ loop {
+ match self.rx.poll() {
+ Ok(Async::Ready(Some(packet))) => self.handle_outgoing_packet(packet),
+ Ok(Async::NotReady) => break,
+ Ok(Async::Ready(None)) | Err(_) => return Err(()),
+ }
+ }
+
+ Ok(Async::NotReady)
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 0f3f9e4..240e6e6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,25 +1,19 @@
#![allow(unused_imports)]
#[macro_use] extern crate log;
-extern crate env_logger;
-
#[macro_use] extern crate structopt_derive;
#[macro_use] extern crate error_chain;
+#[macro_use] extern crate futures;
-
+extern crate env_logger;
extern crate daemonize;
extern crate rand;
extern crate nix;
extern crate structopt;
-extern crate mio;
-
extern crate bytes;
-extern crate futures;
extern crate tokio_core;
extern crate tokio_io;
-extern crate tokio_proto;
-extern crate tokio_service;
extern crate tokio_uds;
extern crate tokio_utun;
extern crate tokio_timer;
@@ -73,7 +67,6 @@ fn main() {
}
Interface::new(&opt.interface).start();
-// WireGuard::start(interface_name).expect("failed to start WireGuard interface");
}
fn daemonize() -> Result<()> {