aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-13 14:32:40 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-13 14:32:40 +0200
commit93e3848ea76e755477bec8d9540a3c4c31ea7320 (patch)
tree31c27266ebf12fa9cef06ab531ee4b9fa7b69c56
parentRestructured MAC/cookie calculation (diff)
downloadwireguard-go-93e3848ea76e755477bec8d9540a3c4c31ea7320.tar.xz
wireguard-go-93e3848ea76e755477bec8d9540a3c4c31ea7320.zip
Terminate on interface deletion
Program now terminates when the interface is removed Increases the number of os threads (relevant for Go <1.5, not tested) More consistent commenting Improved logging (additional peer information)
Diffstat (limited to '')
-rw-r--r--src/constants.go4
-rw-r--r--src/device.go7
-rw-r--r--src/ip.go4
-rw-r--r--src/main.go31
-rw-r--r--src/peer.go19
-rw-r--r--src/receive.go24
-rw-r--r--src/send.go69
-rw-r--r--src/timers.go52
-rw-r--r--src/trie.go19
9 files changed, 132 insertions, 97 deletions
diff --git a/src/constants.go b/src/constants.go
index 0384741..6b0d414 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -29,6 +29,6 @@ const (
QueueInboundSize = 1024
QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8
- MinMessageSize = MessageTransportSize // keep-alive
- MaxMessageSize = 4096 // TODO: make depend on the MTU?
+ MinMessageSize = MessageTransportSize // size of keep-alive
+ MaxMessageSize = (1 << 16) - 1
)
diff --git a/src/device.go b/src/device.go
index a26cc7b..b272544 100644
--- a/src/device.go
+++ b/src/device.go
@@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineBusyMonitor()
+ go device.RoutineWriteToTUN(tun)
go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming()
- go device.RoutineWriteToTUN(tun)
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device
@@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.signal.stop)
- close(device.queue.encryption)
+}
+
+func (device *Device) Wait() {
+ <-device.signal.stop
}
diff --git a/src/ip.go b/src/ip.go
index 36beb9c..752a404 100644
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,17 +5,13 @@ import (
)
const (
- IPv4version = 4
IPv4offsetTotalLength = 2
IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
- IPv4headerSize = 20
)
const (
- IPv6version = 6
IPv6offsetPayloadLength = 4
IPv6offsetSrc = 8
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
- IPv6headerSize = 40
)
diff --git a/src/main.go b/src/main.go
index 50140e3..dc27472 100644
--- a/src/main.go
+++ b/src/main.go
@@ -5,6 +5,7 @@ import (
"log"
"net"
"os"
+ "runtime"
)
/* TODO: Fix logging
@@ -18,6 +19,10 @@ func main() {
}
deviceName := os.Args[1]
+ // increase number of go workers (for Go <1.5)
+
+ runtime.GOMAXPROCS(runtime.NumCPU())
+
// open TUN device
tun, err := CreateTUN(deviceName)
@@ -31,17 +36,21 @@ func main() {
// start configuration lister
- socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
- l, err := net.Listen("unix", socketPath)
- if err != nil {
- log.Fatal("listen error:", err)
- }
-
- for {
- conn, err := l.Accept()
+ go func() {
+ socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
+ l, err := net.Listen("unix", socketPath)
if err != nil {
- log.Fatal("accept error:", err)
+ log.Fatal("listen error:", err)
}
- go ipcHandle(device, conn)
- }
+
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ log.Fatal("accept error:", err)
+ }
+ go ipcHandle(device, conn)
+ }
+ }()
+
+ device.Wait()
}
diff --git a/src/peer.go b/src/peer.go
index c8dc5c0..408c605 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -1,7 +1,9 @@
package main
import (
+ "encoding/base64"
"errors"
+ "fmt"
"net"
"sync"
"time"
@@ -38,9 +40,9 @@ type Peer struct {
/* Both keep-alive timers acts as one (see timers.go)
* They are kept seperate to simplify the implementation.
*/
- keepalivePersistent *time.Timer // set for persistent keepalives
- keepaliveAcknowledgement *time.Timer // set upon recieving messages
- zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3
+ keepalivePersistent *time.Timer // set for persistent keepalives
+ keepalivePassive *time.Timer // set upon recieving messages
+ zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
@@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.mac.Init(pk)
peer.device = device
+ peer.timer.keepalivePassive = NewStoppedTimer()
peer.timer.keepalivePersistent = NewStoppedTimer()
- peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
peer.timer.zeroAllKeys = NewStoppedTimer()
peer.flags.keepaliveWaiting = AtomicFalse
@@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
return peer
}
+func (peer *Peer) String() string {
+ return fmt.Sprintf(
+ "peer(%d %s %s)",
+ peer.id,
+ peer.endpoint.String(),
+ base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+ )
+}
+
func (peer *Peer) Close() {
close(peer.signal.stop)
}
diff --git a/src/receive.go b/src/receive.go
index 99089a9..3e649b6 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
"net"
"sync"
"sync/atomic"
@@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
return
}
- logDebug.Println("Creating response...")
+ logDebug.Println("Creating response message for", peer.String())
outElem := device.NewOutboundElement()
writer := bytes.NewBuffer(outElem.data[:0])
@@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
var elem *QueueInboundElement
device := peer.device
+
+ logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
@@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.KeepKeyFreshReceiving()
- // check if confirming handshake
+ // check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock()
@@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for keep-alive
if len(elem.packet) == 0 {
+ logDebug.Println("Received keep-alive from", peer.String())
return
}
// verify source and strip padding
switch elem.packet[0] >> 4 {
- case IPv4version:
+ case ipv4.Version:
// strip padding
- if len(elem.packet) < IPv4headerSize {
+ if len(elem.packet) < ipv4.HeaderLen {
return
}
@@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
if device.routingTable.LookupIPv4(dst) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
return
}
- case IPv6version:
+ case ipv6.Version:
// strip padding
- if len(elem.packet) < IPv6headerSize {
+ if len(elem.packet) < ipv6.HeaderLen {
return
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
- length += IPv6headerSize
+ length += ipv6.HeaderLen
elem.packet = elem.packet[:length]
// verify IPv6 source
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
if device.routingTable.LookupIPv6(dst) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
return
}
default:
- logDebug.Println("Receieved packet with unknown IP version")
+ logInfo.Println("Packet with invalid IP version from", peer.String())
return
}
@@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
+
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential tun writer, started")
diff --git a/src/send.go b/src/send.go
index 5ea9a8f..d8ddc82 100644
--- a/src/send.go
+++ b/src/send.go
@@ -3,6 +3,8 @@ package main
import (
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
"net"
"sync"
"sync/atomic"
@@ -21,28 +23,26 @@ import (
* The functions in this file occure (roughly) in the order packets are processed.
*/
-/* A work unit
- *
- * The sequential consumers will attempt to take the lock,
- * workers release lock when they have completed work on the packet.
+/* The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work (encryption) on the packet.
*
* If the element is inserted into the "encryption queue",
- * the content is preceeded by enough "junk" to contain the header
+ * the content is preceeded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place)
*/
type QueueOutboundElement struct {
dropped int32
mutex sync.Mutex
- data [MaxMessageSize]byte
- packet []byte // slice of "data" (always!)
- nonce uint64 // nonce for encryption
- keyPair *KeyPair // key-pair for encryption
- peer *Peer // related peer
+ data [MaxMessageSize]byte // slice holding the packet data
+ packet []byte // slice of "data" (always!)
+ nonce uint64 // nonce for encryption
+ keyPair *KeyPair // key-pair for encryption
+ peer *Peer // related peer
}
func (peer *Peer) FlushNonceQueue() {
elems := len(peer.queue.nonce)
- for i := 0; i < elems; i += 1 {
+ for i := 0; i < elems; i++ {
select {
case <-peer.queue.nonce:
default:
@@ -111,14 +111,18 @@ func addToEncryptionQueue(
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+
if tun == nil {
- // dummy
return
}
elem := device.NewOutboundElement()
- device.log.Debug.Println("Routine, TUN Reader: started")
+ logDebug := device.log.Debug
+ logError := device.log.Error
+
+ logDebug.Println("Routine, TUN Reader: started")
+
for {
// read packet
@@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
elem.packet = elem.data[MessageTransportHeaderSize:]
size, err := tun.Read(elem.packet)
if err != nil {
- device.log.Error.Println("Failed to read packet from TUN device:", err)
- continue
+
+ // stop process
+
+ logError.Println("Failed to read packet from TUN device:", err)
+ device.Close()
+ return
}
+
elem.packet = elem.packet[:size]
- if len(elem.packet) < IPv4headerSize {
- device.log.Error.Println("Packet too short, length:", size)
+ if len(elem.packet) < ipv4.HeaderLen {
+ logError.Println("Packet too short, length:", size)
continue
}
@@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
var peer *Peer
switch elem.packet[0] >> 4 {
- case IPv4version:
+ case ipv4.Version:
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
- case IPv6version:
+ case ipv6.Version:
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
default:
- device.log.Debug.Println("Receieved packet with unknown IP version")
+ logDebug.Println("Receieved packet with unknown IP version")
}
if peer == nil {
continue
}
+
if peer.endpoint == nil {
- device.log.Debug.Println("No known endpoint for peer", peer.id)
+ logDebug.Println("No known endpoint for peer", peer.String())
continue
}
@@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
device := peer.device
logDebug := device.log.Debug
- logDebug.Println("Routine, nonce worker, started for peer", peer.id)
+ logDebug.Println("Routine, nonce worker, started for peer", peer.String())
func() {
@@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
}
}
signalSend(peer.signal.handshakeBegin)
- logDebug.Println("Waiting for key-pair, peer", peer.id)
+ logDebug.Println("Awaiting key-pair for", peer.String())
select {
case <-peer.signal.newKeyPair:
- logDebug.Println("Key-pair negotiated for peer", peer.id)
+ logDebug.Println("Key-pair negotiated for", peer.String())
goto NextPacket
case <-peer.signal.flushNonceQueue:
- logDebug.Println("Clearing queue for peer", peer.id)
+ logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue()
elem = nil
goto NextPacket
@@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
device := peer.device
logDebug := device.log.Debug
- logDebug.Println("Routine, sequential sender, started for peer", peer.id)
+ logDebug.Println("Routine, sequential sender, started for", peer.String())
for {
select {
case <-peer.signal.stop:
- logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
+ logDebug.Println("Routine, sequential sender, stopped for", peer.String())
return
+
case work := <-peer.queue.outbound:
work.mutex.Lock()
if work.IsDropped() {
@@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
- logDebug.Println("No endpoint for peer:", peer.id)
+ logDebug.Println("No endpoint for", peer.String())
return
}
@@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
}
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
- // reset keep-alive (passive keep-alives / acknowledgements)
+ // reset keep-alive
peer.TimerResetKeepalive()
}()
diff --git a/src/timers.go b/src/timers.go
index 6393955..2e5046e 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
* - First transport message under the "next" key
*/
func (peer *Peer) EventHandshakeComplete() {
- peer.device.log.Debug.Println("Handshake completed")
+ peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
signalSend(peer.signal.handshakeCompleted)
}
@@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
// stop acknowledgement timer
- timerStop(peer.timer.keepaliveAcknowledgement)
+ timerStop(peer.timer.keepalivePassive)
atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
}
@@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
device := peer.device
logDebug := device.log.Debug
- logDebug.Println("Routine, timer handler, started for peer", peer.id)
+ logDebug.Println("Routine, timer handler, started for peer", peer.String())
for {
select {
@@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.timer.keepalivePersistent.C:
- logDebug.Println("Sending persistent keep-alive to peer", peer.id)
+ logDebug.Println("Sending persistent keep-alive to", peer.String())
peer.SendKeepAlive()
peer.TimerResetKeepalive()
- case <-peer.timer.keepaliveAcknowledgement.C:
+ case <-peer.timer.keepalivePassive.C:
- logDebug.Println("Sending passive persistent keep-alive to peer", peer.id)
+ logDebug.Println("Sending passive persistent keep-alive to", peer.String())
peer.SendKeepAlive()
peer.TimerResetKeepalive()
@@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.timer.zeroAllKeys.C:
- logDebug.Println("Clearing all key material for peer", peer.id)
+ logDebug.Println("Clearing all key material for", peer.String())
// zero out key pairs
@@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
var elem *QueueOutboundElement
+ logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
- logDebug.Println("Routine, handshake initator, started for peer", peer.id)
+ logDebug.Println("Routine, handshake initator, started for", peer.String())
- for run := true; run; {
- var err error
- var attempts uint
- var deadline time.Time
+ for {
// wait for signal
@@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
// wait for handshake
- run = func() bool {
- for {
+ func() {
+ var err error
+ var deadline time.Time
+ for attempts := uint(1); ; attempts++ {
// clear completed signal
select {
case <-peer.signal.handshakeCompleted:
case <-peer.signal.stop:
- return false
+ return
default:
}
@@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
}
elem, err = peer.BeginHandshakeInitiation()
if err != nil {
- logError.Println("Failed to create initiation message:", err)
- break
+ logError.Println("Failed to create initiation message", err, "for", peer.String())
+ return
}
// set timeout
- attempts += 1
if attempts == 1 {
deadline = time.Now().Add(MaxHandshakeAttemptTime)
}
timeout := time.NewTimer(RekeyTimeout)
- logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
+ logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
// wait for handshake or timeout
select {
+
case <-peer.signal.stop:
- return true
+ return
case <-peer.signal.handshakeCompleted:
<-timeout.C
- return true
+ return
case <-timeout.C:
- logDebug.Println("Timeout")
-
- // check if sufficient time for retry
-
if deadline.Before(time.Now().Add(RekeyTimeout)) {
+ logInfo.Println("Handshake negotiation timed out for", peer.String())
signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent)
- timerStop(peer.timer.keepaliveAcknowledgement)
- return true
+ timerStop(peer.timer.keepalivePassive)
+ return
}
}
}
- return true
}()
signalClear(peer.signal.handshakeBegin)
diff --git a/src/trie.go b/src/trie.go
index c2304b2..e81b5b6 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -23,7 +23,8 @@ type Trie struct {
bits []byte
peer *Peer
- // Index of "branching" bit
+ // index of "branching" bit
+
bit_at_byte uint
bit_at_shift uint
}
@@ -36,7 +37,7 @@ type Trie struct {
func commonBits(ip1 net.IP, ip2 net.IP) uint {
var i uint
size := uint(len(ip1))
- for i = 0; i < size; i += 1 {
+ for i = 0; i < size; i++ {
v := ip1[i] ^ ip2[i]
if v != 0 {
v >>= 1
@@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node
}
- // Walk recursivly
+ // walk recursivly
node.child[0] = node.child[0].RemovePeer(p)
node.child[1] = node.child[1].RemovePeer(p)
@@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node
}
- // Remove peer & merge
+ // remove peer & merge
node.peer = nil
if node.child[0] == nil {
@@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
- // At leaf
+ // at leaf
if node == nil {
return &Trie{
@@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
}
}
- // Traverse deeper
+ // traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
@@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return node
}
- // Split node
+ // split node
newNode := &Trie{
bits: ip,
@@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
cidr = min(cidr, common)
- // Check for shorter prefix
+ // check for shorter prefix
if newNode.cidr == cidr {
bit := newNode.choose(node.bits)
@@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return newNode
}
- // Create new parent for node & newNode
+ // create new parent for node & newNode
parent := &Trie{
bits: ip,