diff options
author | Jake McGinty <me@jake.su> | 2017-12-30 12:34:09 -0800 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2017-12-30 12:34:09 -0800 |
commit | a2c20c5b3abcb34beae91dac5c63aece06d0c903 (patch) | |
tree | 58c7195be4489bfee9b39fb87ba2e6a975e53eba /src | |
parent | add license and readme (diff) | |
download | wireguard-rs-a2c20c5b3abcb34beae91dac5c63aece06d0c903.tar.xz wireguard-rs-a2c20c5b3abcb34beae91dac5c63aece06d0c903.zip |
PeerServer custom future refactor
Diffstat (limited to 'src')
-rw-r--r-- | src/interface/mod.rs | 179 | ||||
-rw-r--r-- | src/interface/peer_server.rs | 174 | ||||
-rw-r--r-- | src/main.rs | 11 |
3 files changed, 217 insertions, 147 deletions
diff --git a/src/interface/mod.rs b/src/interface/mod.rs index 8cbbfb3..943c396 100644 --- a/src/interface/mod.rs +++ b/src/interface/mod.rs @@ -1,6 +1,9 @@ mod config; +mod peer_server; use self::config::{ConfigurationServiceManager, UpdateEvent, Command, ConfigurationCodec}; +use self::peer_server::{PeerServer, PeerServerMessage}; + use base64; use hex; use byteorder::{ByteOrder, BigEndian, LittleEndian}; @@ -25,32 +28,25 @@ use tokio_io::codec::{Framed, Encoder, Decoder}; use tokio_uds::{UnixListener}; use tokio_timer::{Interval, Timer}; -fn debug_packet(header: &str, packet: &[u8]) { + +pub fn debug_packet(header: &str, packet: &[u8]) { let packet = Ipv4Packet::new(packet); debug!("{} {:?}", header, packet); } -pub struct Interface { - name: String, - info: Rc<RefCell<InterfaceInfo>>, - peers: Rc<RefCell<HashMap<[u8; 32], Rc<RefCell<Peer>>>>>, - ids: Rc<RefCell<HashMap<u32, Rc<RefCell<Peer>>>>>, -} - -struct VecUdpCodec; -impl UdpCodec for VecUdpCodec { - type In = (SocketAddr, Vec<u8>); - type Out = (SocketAddr, Vec<u8>); +pub type SharedPeer = Rc<RefCell<Peer>>; +pub type SharedState = Rc<RefCell<State>>; - fn decode(&mut self, src: &SocketAddr, buf: &[u8]) -> io::Result<Self::In> { - Ok((*src, buf.to_vec())) - } +#[derive(Default)] +pub struct State { + pubkey_map: HashMap<[u8; 32], SharedPeer>, + index_map: HashMap<u32, SharedPeer>, + interface_info: InterfaceInfo, +} - fn encode(&mut self, msg: Self::Out, buf: &mut Vec<u8>) -> SocketAddr { - let (addr, mut data) = msg; - buf.append(&mut data); - addr - } +pub struct Interface { + name: String, + state: SharedState, } struct VecUtunCodec; @@ -76,15 +72,14 @@ impl UtunCodec for VecUtunCodec { impl Interface { pub fn new(name: &str) -> Self { - let info = Rc::new(RefCell::new(InterfaceInfo::default())); - let peers = Rc::new(RefCell::new(HashMap::new())); - let ids = Rc::new(RefCell::new(HashMap::new())); - let _config_service = ConfigurationServiceManager::new(name); + let state = State { + pubkey_map: HashMap::new(), + index_map: HashMap::new(), + interface_info: InterfaceInfo::default(), + }; Interface { name: name.to_owned(), - info, - peers, - ids, + state: Rc::new(RefCell::new(state)), } } @@ -92,109 +87,20 @@ impl Interface { let mut core = Core::new().unwrap(); let (utun_tx, utun_rx) = unsync::mpsc::channel::<Vec<u8>>(1024); - let udp_socket = UdpSocket::bind(&([0,0,0,0], 0).into(), &core.handle()).unwrap(); - let (tx, rx) = unsync::mpsc::channel::<(SocketAddr, Vec<u8>)>(1024); - let (udp_writer, udp_reader) = udp_socket.framed(VecUdpCodec{}).split(); - let udp_read_fut = udp_reader.for_each({ - let ids_ref = self.ids.clone(); - let handle = core.handle(); - let tx = tx.clone(); - let interface_info = self.info.clone(); - move |(_socket_addr, packet)| { - debug!("got a UDP packet of length {}, packet type {}", packet.len(), packet[0]); - match packet[0] { - 1 => { - info!("got handshake initialization."); - }, - 2 => { - let their_index = LittleEndian::read_u32(&packet[4..]); - let our_index = LittleEndian::read_u32(&packet[8..]); - let mut ids = ids_ref.borrow_mut(); - let peer_ref = ids.get(&our_index).unwrap().clone(); - let mut peer = peer_ref.borrow_mut(); - peer.sessions.next.as_mut().unwrap().their_index = their_index; - let payload_len = peer.next_noise().expect("pending noise session") - .read_message(&packet[12..60], &mut []).unwrap(); - assert!(payload_len == 0); - peer.ratchet_session().unwrap(); - info!("got handshake response, ratcheted session."); - let tx = tx.clone(); - - let interface_info = interface_info.borrow(); - let noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap()) - .local_private_key(&interface_info.private_key.expect("no private key!")) - .remote_public_key(&peer.info.pub_key) - .prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes()) - .psk(2, &peer.info.psk.expect("no psk!")) - .build_initiator().unwrap(); - peer.set_next_session(noise.into()); - let _ = ids.insert(peer.our_next_index().unwrap(), peer_ref.clone()); - - let init_packet = peer.get_handshake_packet(); - let endpoint = peer.info.endpoint.unwrap().clone(); - - let timer = Timer::default(); - let sleep = timer.sleep(Duration::from_secs(120)); - let boop = sleep.and_then({ - let handle = handle.clone(); - let peer_ref = peer_ref.clone(); - let interface_info = interface_info.clone(); - move |_| { - info!("sending rekey!"); - handle.spawn(tx.clone().send((endpoint, init_packet)) - .map(|_| ()) - .map_err(|_| ())); - Ok(()) - } - }).map_err(|_|()); - handle.spawn(boop); - }, - 4 => { - let our_index_received = LittleEndian::read_u32(&packet[4..]); - let nonce = LittleEndian::read_u64(&packet[8..]); - - let mut raw_packet = [0u8; 1500]; - let ids = ids_ref.borrow(); - let lookup = ids.get(&our_index_received); - if let Some(ref peer) = lookup { - let mut peer = peer.borrow_mut(); - // info!("retrieved peer with pubkey {}", base64::encode(&peer.pubkey)); - // info!("ok going to try to decrypt"); - - peer.rx_bytes += packet.len(); - let noise = peer.current_noise().expect("current noise session"); - noise.set_receiving_nonce(nonce).unwrap(); - let payload_len = noise.read_message(&packet[16..], &mut raw_packet).unwrap(); - debug_packet("received TRANSPORT: ", &raw_packet[..payload_len]); - handle.spawn(utun_tx.clone().send(raw_packet[..payload_len].to_owned()) - .map(|_| ()) - .map_err(|_| ())); - } - }, - _ => unimplemented!() - } - Ok(()) - } - }).map_err(|_| ()); - - let udp_write_fut = udp_writer.sink_map_err(|_| ()).send_all( - rx.map(|(addr, packet)| { - debug!("sending encrypted UDP packet"); - (addr, packet) - }).map_err(|_| ())).map_err(|_| ()); + let peer_server = PeerServer::bind(core.handle(), self.state.clone(), utun_tx.clone()); let utun_stream = UtunStream::connect(&self.name, &core.handle()).unwrap().framed(VecUtunCodec{}); let (utun_writer, utun_reader) = utun_stream.split(); let utun_fut = utun_reader.for_each({ - let ids = self.ids.clone(); + let state = self.state.clone(); let utun_handle = core.handle(); - let udp_tx = tx.clone(); + let udp_tx = peer_server.udp_tx(); move |packet| { debug_packet("received UTUN packet: ", &packet); + let state = state.borrow(); let mut ping_packet = [0u8; 1500]; - let ids = ids.borrow(); - let (_key, peer) = ids.iter().next().unwrap(); // TODO destination IP peer lookup + let (_key, peer) = state.pubkey_map.iter().next().unwrap(); // TODO destination IP peer lookup let mut peer = peer.borrow_mut(); ping_packet[0] = 4; let their_index = peer.their_current_index().expect("no current index for them"); @@ -223,8 +129,7 @@ impl Interface { let h = handle.clone(); let config_server = listener.incoming().for_each({ let config_tx = config_tx.clone(); - let info = self.info.clone(); - let peers = self.peers.clone(); + let state = self.state.clone(); move |(stream, _)| { let (sink, stream) = stream.framed(ConfigurationCodec {}).split(); debug!("UnixServer connection."); @@ -232,17 +137,17 @@ impl Interface { let handle = h.clone(); let responses = stream.and_then({ let config_tx = config_tx.clone(); - let info = info.clone(); - let peers = peers.clone(); + let state = state.clone(); move |command| { + let state = state.borrow(); match command { Command::Set(_version, items) => { config_tx.clone().send_all(stream::iter_ok(items)).wait().unwrap(); future::ok("errno=0\nerrno=0\n\n".to_string()) }, Command::Get(_version) => { - let info = info.borrow(); - let peers = peers.borrow(); + let info = &state.interface_info; + let peers = &state.pubkey_map; let mut s = String::new(); if let Some(private_key) = info.private_key { s.push_str(&format!("private_key={}\n", hex::encode(private_key))); @@ -266,26 +171,24 @@ impl Interface { }).map_err(|_| ()); let config_fut = config_rx.for_each({ - let tx = tx.clone(); + let tx = peer_server.udp_tx().clone(); let handle = handle.clone(); + let state = self.state.clone(); move |event| { - let interface_info = self.info.clone(); + let mut state = state.borrow_mut(); match event { UpdateEvent::PrivateKey(private_key) => { - let mut interface_info = interface_info.borrow_mut(); - interface_info.private_key = Some(private_key); + state.interface_info.private_key = Some(private_key); debug!("set new private key"); }, UpdateEvent::ListenPort(port) => { - let mut interface_info = interface_info.borrow_mut(); - interface_info.listen_port = Some(port); + state.interface_info.listen_port = Some(port); debug!("set new listen port"); }, UpdateEvent::UpdatePeer(info) => { info!("added new peer: {}", info); - let interface_info = interface_info.borrow(); let mut noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap()) - .local_private_key(&interface_info.private_key.expect("no private key!")) + .local_private_key(&state.interface_info.private_key.expect("no private key!")) .remote_public_key(&info.pub_key) .prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes()) .psk(2, &info.psk.expect("no psk!")) @@ -298,8 +201,8 @@ impl Interface { let our_index = peer.our_next_index().unwrap(); let peer = Rc::new(RefCell::new(peer)); - let _ = self.ids.borrow_mut().insert(our_index, peer.clone()); - let _ = self.peers.borrow_mut().insert(info.pub_key, peer); + let _ = state.index_map.insert(our_index, peer.clone()); + let _ = state.pubkey_map.insert(info.pub_key, peer); handle.spawn(tx.clone().send((info.endpoint.unwrap(), init_packet)) .map(|_| ()) @@ -312,6 +215,6 @@ impl Interface { } }).map_err(|_| ()); - core.run(utun_fut.join(utun_write_fut.join(udp_read_fut.join(udp_write_fut.join(config_fut.join(config_server)))))).unwrap(); + core.run(peer_server.join(utun_fut.join(utun_write_fut.join(config_fut.join(config_server))))).unwrap(); } } diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs new file mode 100644 index 0000000..b812677 --- /dev/null +++ b/src/interface/peer_server.rs @@ -0,0 +1,174 @@ +use super::{SharedState, SharedPeer, debug_packet}; + +use std::io; +use std::net::SocketAddr; +use std::time::Duration; + +use byteorder::{ByteOrder, BigEndian, LittleEndian}; +use futures::{Async, Future, Stream, Sink, Poll, future, unsync, sync, stream}; +use tokio_core::net::{UdpSocket, UdpCodec, UdpFramed}; +use tokio_core::reactor::Handle; +use tokio_io::codec::Framed; +use tokio_timer::{Interval, Timer}; +use snow::NoiseBuilder; + + +pub type PeerServerMessage = (SocketAddr, Vec<u8>); +struct VecUdpCodec; +impl UdpCodec for VecUdpCodec { + type In = PeerServerMessage; + type Out = PeerServerMessage; + + fn decode(&mut self, src: &SocketAddr, buf: &[u8]) -> io::Result<Self::In> { + Ok((*src, buf.to_vec())) + } + + fn encode(&mut self, msg: Self::Out, buf: &mut Vec<u8>) -> SocketAddr { + let (addr, mut data) = msg; + buf.append(&mut data); + addr + } +} + + +pub struct PeerServer { + handle: Handle, + shared_state: SharedState, + udp_stream: stream::SplitStream<UdpFramed<VecUdpCodec>>, + rx: unsync::mpsc::Receiver<Vec<u8>>, + udp_tx: unsync::mpsc::Sender<(SocketAddr, Vec<u8>)>, + tunnel_tx: unsync::mpsc::Sender<Vec<u8>>, + pub tx: unsync::mpsc::Sender<Vec<u8>>, +} + +impl PeerServer { + pub fn bind(handle: Handle, shared_state: SharedState, tunnel_tx: unsync::mpsc::Sender<Vec<u8>>) -> Self { + let socket = UdpSocket::bind(&([0,0,0,0], 0).into(), &handle.clone()).unwrap(); + let (udp_sink, udp_stream) = socket.framed(VecUdpCodec{}).split(); + let (udp_tx, udp_rx) = unsync::mpsc::channel::<(SocketAddr, Vec<u8>)>(1024); + let (tx, rx) = unsync::mpsc::channel::<Vec<u8>>(1024); + + let udp_write_passthrough = udp_sink.sink_map_err(|_| ()).send_all( + udp_rx.map(|(addr, packet)| { + debug_packet("sending UDP: ", &packet); + (addr, packet) + }).map_err(|_| ())) + .then(|_| Ok(())); + handle.spawn(udp_write_passthrough); + + PeerServer { + handle, shared_state, udp_stream, udp_tx, tunnel_tx, tx, rx + } + } + + pub fn tx(&self) -> unsync::mpsc::Sender<Vec<u8>> { + self.tx.clone() + } + + pub fn udp_tx(&self) -> unsync::mpsc::Sender<(SocketAddr, Vec<u8>)> { + self.udp_tx.clone() + } + + fn handle_incoming_packet(&mut self, addr: SocketAddr, packet: Vec<u8>) { + debug!("got a UDP packet of length {}, packet type {}", packet.len(), packet[0]); + let mut state = self.shared_state.borrow_mut(); + match packet[0] { + 1 => { + info!("got handshake initialization."); + }, + 2 => { + let their_index = LittleEndian::read_u32(&packet[4..]); + let our_index = LittleEndian::read_u32(&packet[8..]); + let peer_ref = state.index_map.get(&our_index).unwrap().clone(); + let mut peer = peer_ref.borrow_mut(); + peer.sessions.next.as_mut().unwrap().their_index = their_index; + let payload_len = peer.next_noise().expect("pending noise session") + .read_message(&packet[12..60], &mut []).unwrap(); + assert!(payload_len == 0); + peer.ratchet_session().unwrap(); + info!("got handshake response, ratcheted session."); + + let noise = NoiseBuilder::new("Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s".parse().unwrap()) + .local_private_key(&state.interface_info.private_key.expect("no private key!")) + .remote_public_key(&peer.info.pub_key) + .prologue("WireGuard v1 zx2c4 Jason@zx2c4.com".as_bytes()) + .psk(2, &peer.info.psk.expect("no psk!")) + .build_initiator().unwrap(); + peer.set_next_session(noise.into()); + + let _ = state.index_map.insert(peer.our_next_index().unwrap(), peer_ref.clone()); + + let init_packet = peer.get_handshake_packet(); + let endpoint = peer.info.endpoint.unwrap().clone(); + + let timer = Timer::default(); + let sleep = timer.sleep(Duration::from_secs(120)); + let boop = sleep.and_then({ + let handle = self.handle.clone(); + let tx = self.udp_tx.clone(); + let peer_ref = peer_ref.clone(); + move |_| { + info!("sending rekey!"); + handle.spawn(tx.clone().send((endpoint, init_packet)) + .map(|_| ()) + .map_err(|_| ())); + Ok(()) + } + }).map_err(|_|()); + self.handle.spawn(boop); + }, + 4 => { + let our_index_received = LittleEndian::read_u32(&packet[4..]); + let nonce = LittleEndian::read_u64(&packet[8..]); + + let mut raw_packet = [0u8; 1500]; + let lookup = state.index_map.get(&our_index_received); + if let Some(ref peer) = lookup { + let mut peer = peer.borrow_mut(); + + peer.rx_bytes += packet.len(); + let noise = peer.current_noise().expect("current noise session"); + noise.set_receiving_nonce(nonce).unwrap(); + let payload_len = noise.read_message(&packet[16..], &mut raw_packet).unwrap(); + debug_packet("received TRANSPORT: ", &raw_packet[..payload_len]); + self.handle.spawn(self.tunnel_tx.clone().send(raw_packet[..payload_len].to_owned()) + .map(|_| ()) + .map_err(|_| ())); + } + }, + _ => unimplemented!() + } + } + + fn handle_outgoing_packet(&mut self, packet: Vec<u8>) { + debug!("handle_outgoing_packet()"); + + } +} + +impl Future for PeerServer { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<Self::Item, Self::Error> { + // Handle UDP packets from the outside world + loop { + match self.udp_stream.poll() { + Ok(Async::Ready(Some((addr, packet)))) => self.handle_incoming_packet(addr, packet), + Ok(Async::NotReady) => break, + Ok(Async::Ready(None)) | Err(_) => return Err(()), + } + } + + // Handle packets coming from the local tunnel + loop { + match self.rx.poll() { + Ok(Async::Ready(Some(packet))) => self.handle_outgoing_packet(packet), + Ok(Async::NotReady) => break, + Ok(Async::Ready(None)) | Err(_) => return Err(()), + } + } + + Ok(Async::NotReady) + } +} diff --git a/src/main.rs b/src/main.rs index 0f3f9e4..240e6e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,19 @@ #![allow(unused_imports)] #[macro_use] extern crate log; -extern crate env_logger; - #[macro_use] extern crate structopt_derive; #[macro_use] extern crate error_chain; +#[macro_use] extern crate futures; - +extern crate env_logger; extern crate daemonize; extern crate rand; extern crate nix; extern crate structopt; -extern crate mio; - extern crate bytes; -extern crate futures; extern crate tokio_core; extern crate tokio_io; -extern crate tokio_proto; -extern crate tokio_service; extern crate tokio_uds; extern crate tokio_utun; extern crate tokio_timer; @@ -73,7 +67,6 @@ fn main() { } Interface::new(&opt.interface).start(); -// WireGuard::start(interface_name).expect("failed to start WireGuard interface"); } fn daemonize() -> Result<()> { |