aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-09-09 15:03:01 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-09-09 15:03:01 +0200
commitf212795e51d839910085e08f9c6b09eac11863d3 (patch)
tree8ea6dba582442e46c1b21fe58f52be20f02c5fed /src
parentFixed TUN interface implementation os OS X (diff)
downloadwireguard-go-f212795e51d839910085e08f9c6b09eac11863d3.tar.xz
wireguard-go-f212795e51d839910085e08f9c6b09eac11863d3.zip
Improved readability of send/receive code
Diffstat (limited to '')
-rw-r--r--src/receive.go229
-rw-r--r--src/send.go167
2 files changed, 178 insertions, 218 deletions
diff --git a/src/receive.go b/src/receive.go
index 97646d8..09fca77 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -128,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() {
// read next datagram
- size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes
+ size, raddr, err := conn.ReadFromUDP(buffer[:])
if err != nil {
break
@@ -222,7 +222,7 @@ func (device *Device) RoutineReceiveIncomming() {
}
func (device *Device) RoutineDecryption() {
- var elem *QueueInboundElement
+
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
@@ -230,50 +230,51 @@ func (device *Device) RoutineDecryption() {
for {
select {
- case elem = <-device.queue.decryption:
case <-device.signal.stop:
+ logDebug.Println("Routine, decryption worker, stopped")
return
- }
- // check if dropped
+ case elem := <-device.queue.decryption:
- if elem.IsDropped() {
- continue
- }
+ // check if dropped
- // split message into fields
-
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
-
- // decrypt with key-pair
-
- var err error
- copy(nonce[4:], counter)
- elem.counter = binary.LittleEndian.Uint64(counter)
- elem.keyPair.receive.mutex.RLock()
- if elem.keyPair.receive.aead == nil {
- // very unlikely (the key was deleted during queuing)
- elem.Drop()
- } else {
- elem.packet, err = elem.keyPair.receive.aead.Open(
- elem.buffer[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
+ if elem.IsDropped() {
+ continue
+ }
+
+ // split message into fields
+
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // decrypt with key-pair
+
+ copy(nonce[4:], counter)
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ elem.keyPair.receive.mutex.RLock()
+ if elem.keyPair.receive.aead == nil {
+ // very unlikely (the key was deleted during queuing)
elem.Drop()
+ } else {
+ var err error
+ elem.packet, err = elem.keyPair.receive.aead.Open(
+ elem.buffer[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.Drop()
+ }
}
+
+ elem.keyPair.receive.mutex.RUnlock()
+ elem.mutex.Unlock()
}
- elem.keyPair.receive.mutex.RUnlock()
- elem.mutex.Unlock()
}
}
/* Handles incomming packets related to handshake
- *
- *
*/
func (device *Device) RoutineHandshake() {
@@ -473,7 +474,6 @@ func (device *Device) RoutineHandshake() {
}
func (peer *Peer) RoutineSequentialReceiver() {
- var elem *QueueInboundElement
device := peer.device
@@ -483,118 +483,119 @@ func (peer *Peer) RoutineSequentialReceiver() {
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
for {
- // wait for decryption
select {
case <-peer.signal.stop:
+ logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
return
- case elem = <-peer.queue.inbound:
- }
- elem.mutex.Lock()
- // process packet
+ case elem := <-peer.queue.inbound:
- if elem.IsDropped() {
- continue
- }
+ // wait for decryption
+
+ elem.mutex.Lock()
+ if elem.IsDropped() {
+ continue
+ }
- // check for replay
+ // check for replay
- if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
- continue
- }
+ if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+ continue
+ }
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
- peer.KeepKeyFreshReceiving()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
+ peer.KeepKeyFreshReceiving()
- // check if using new key-pair
+ // check if using new key-pair
- kp := &peer.keyPairs
- kp.mutex.Lock()
- if kp.next == elem.keyPair {
- peer.TimerHandshakeComplete()
- if kp.previous != nil {
- device.DeleteKeyPair(kp.previous)
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+ if kp.next == elem.keyPair {
+ peer.TimerHandshakeComplete()
+ if kp.previous != nil {
+ device.DeleteKeyPair(kp.previous)
+ }
+ kp.previous = kp.current
+ kp.current = kp.next
+ kp.next = nil
}
- kp.previous = kp.current
- kp.current = kp.next
- kp.next = nil
- }
- kp.mutex.Unlock()
+ kp.mutex.Unlock()
- // check for keep-alive
+ // check for keep-alive
- if len(elem.packet) == 0 {
- logDebug.Println("Received keep-alive from", peer.String())
- continue
- }
- peer.TimerDataReceived()
+ if len(elem.packet) == 0 {
+ logDebug.Println("Received keep-alive from", peer.String())
+ continue
+ }
+ peer.TimerDataReceived()
- // verify source and strip padding
+ // verify source and strip padding
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
- // strip padding
+ // strip padding
- if len(elem.packet) < ipv4.HeaderLen {
- continue
- }
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- continue
- }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
- elem.packet = elem.packet[:length]
+ elem.packet = elem.packet[:length]
- // verify IPv4 source
+ // verify IPv4 source
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.routingTable.LookupIPv4(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
- continue
- }
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.routingTable.LookupIPv4(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- case ipv6.Version:
+ case ipv6.Version:
- // strip padding
+ // strip padding
- if len(elem.packet) < ipv6.HeaderLen {
- continue
- }
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- continue
- }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
- elem.packet = elem.packet[:length]
+ elem.packet = elem.packet[:length]
- // verify IPv6 source
+ // verify IPv6 source
+
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.routingTable.LookupIPv6(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.routingTable.LookupIPv6(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
+ default:
+ logInfo.Println("Packet with invalid IP version from", peer.String())
continue
}
- default:
- logInfo.Println("Packet with invalid IP version from", peer.String())
- continue
- }
-
- // write to tun
+ // write to tun
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
- _, err := device.tun.device.Write(elem.packet)
- device.PutMessageBuffer(elem.buffer)
- if err != nil {
- logError.Println("Failed to write packet to TUN device:", err)
+ atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ _, err := device.tun.device.Write(elem.packet)
+ device.PutMessageBuffer(elem.buffer)
+ if err != nil {
+ logError.Println("Failed to write packet to TUN device:", err)
+ }
}
}
}
diff --git a/src/send.go b/src/send.go
index c598ad4..e9dfb54 100644
--- a/src/send.go
+++ b/src/send.go
@@ -35,7 +35,7 @@ type QueueOutboundElement struct {
dropped int32
mutex sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
- packet []byte // slice of "data" (always!)
+ packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
keyPair *KeyPair // key-pair for encryption
peer *Peer // related peer
@@ -52,11 +52,6 @@ func (peer *Peer) FlushNonceQueue() {
}
}
-var (
- ErrorNoEndpoint = errors.New("No known endpoint for peer")
- ErrorNoConnection = errors.New("No UDP socket for device")
-)
-
func (device *Device) NewOutboundElement() *QueueOutboundElement {
return &QueueOutboundElement{
dropped: AtomicFalse,
@@ -118,14 +113,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
defer peer.mutex.RUnlock()
endpoint := peer.endpoint
- conn := peer.device.net.conn
-
if endpoint == nil {
- return 0, ErrorNoEndpoint
+ return 0, errors.New("No known endpoint for peer")
}
+ conn := peer.device.net.conn
if conn == nil {
- return 0, ErrorNoConnection
+ return 0, errors.New("No UDP socket for device")
}
return conn.WriteToUDP(buffer, endpoint)
@@ -189,16 +183,6 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
- // check if known endpoint (drop early)
-
- peer.mutex.RLock()
- if peer.endpoint == nil {
- peer.mutex.RUnlock()
- logDebug.Println("No known endpoint for peer", peer.String())
- continue
- }
- peer.mutex.RUnlock()
-
// insert into nonce/pre-handshake queue
signalSend(peer.signal.handshakeReset)
@@ -211,86 +195,61 @@ func (device *Device) RoutineReadFromTUN() {
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
*
- * TODO: Avoid dynamic allocation of work queue elements
- *
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair
- var elem *QueueOutboundElement
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.String())
- func() {
-
- for {
- NextPacket:
-
- // wait for packet
+ for {
+ NextPacket:
+ select {
+ case <-peer.signal.stop:
+ return
- if elem == nil {
- select {
- case elem = <-peer.queue.nonce:
- case <-peer.signal.stop:
- return
- }
- }
+ case elem := <-peer.queue.nonce:
// wait for key pair
for {
- select {
- case <-peer.signal.newKeyPair:
- default:
- }
-
keyPair = peer.keyPairs.Current()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime {
break
}
}
+
signalSend(peer.signal.handshakeBegin)
logDebug.Println("Awaiting key-pair for", peer.String())
select {
case <-peer.signal.newKeyPair:
- logDebug.Println("Key-pair negotiated for", peer.String())
- goto NextPacket
-
case <-peer.signal.flushNonceQueue:
logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue()
- elem = nil
goto NextPacket
-
case <-peer.signal.stop:
return
}
}
- // process current packet
+ // populate work element
- if elem != nil {
-
- // create work element
-
- elem.keyPair = keyPair
- elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
- elem.dropped = AtomicFalse
- elem.peer = peer
- elem.mutex.Lock()
+ elem.peer = peer
+ elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+ elem.keyPair = keyPair
+ elem.dropped = AtomicFalse
+ elem.mutex.Lock()
- // add to parallel and sequential queue
+ // add to parallel and sequential queue
- addToEncryptionQueue(device.queue.encryption, elem)
- addToOutboundQueue(peer.queue.outbound, elem)
- elem = nil
- }
+ addToEncryptionQueue(device.queue.encryption, elem)
+ addToOutboundQueue(peer.queue.outbound, elem)
}
- }()
+ }
}
/* Encrypts the elements in the queue
@@ -300,7 +259,6 @@ func (peer *Peer) RoutineNonce() {
*/
func (device *Device) RoutineEncryption() {
- var elem *QueueOutboundElement
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
@@ -311,62 +269,62 @@ func (device *Device) RoutineEncryption() {
// fetch next element
select {
- case elem = <-device.queue.encryption:
case <-device.signal.stop:
logDebug.Println("Routine, encryption worker, stopped")
return
- }
- // check if dropped
+ case elem := <-device.queue.encryption:
- if elem.IsDropped() {
- continue
- }
+ // check if dropped
+
+ if elem.IsDropped() {
+ continue
+ }
- // populate header fields
+ // populate header fields
- header := elem.buffer[:MessageTransportHeaderSize]
+ header := elem.buffer[:MessageTransportHeaderSize]
- fieldType := header[0:4]
- fieldReceiver := header[4:8]
- fieldNonce := header[8:16]
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
- // pad content to MTU size
+ // pad content to multiple of 16
- mtu := int(atomic.LoadInt32(&device.tun.mtu))
- pad := len(elem.packet) % PaddingMultiple
- if pad > 0 {
- for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
- elem.packet = append(elem.packet, 0)
+ mtu := int(atomic.LoadInt32(&device.tun.mtu))
+ rem := len(elem.packet) % PaddingMultiple
+ if rem > 0 {
+ for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ {
+ elem.packet = append(elem.packet, 0)
+ }
}
- // TODO: How good is this code
- }
- // encrypt content (append to header)
-
- binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.keyPair.send.mutex.RLock()
- if elem.keyPair.send.aead == nil {
- // very unlikely (the key was deleted during queuing)
- elem.Drop()
- } else {
- elem.packet = elem.keyPair.send.aead.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
- }
- elem.keyPair.send.mutex.RUnlock()
- elem.mutex.Unlock()
+ // encrypt content (append to header)
+
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.keyPair.send.mutex.RLock()
+ if elem.keyPair.send.aead == nil {
+ // very unlikely (the key was deleted during queuing)
+ elem.Drop()
+ } else {
+ elem.packet = elem.keyPair.send.aead.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ }
+ elem.mutex.Unlock()
+ elem.keyPair.send.mutex.RUnlock()
- // refresh key if necessary
+ // refresh key if necessary
- elem.peer.KeepKeyFreshSending()
+ elem.peer.KeepKeyFreshSending()
+ }
}
}
@@ -399,6 +357,7 @@ func (peer *Peer) RoutineSequentialSender() {
_, err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
+ logDebug.Println("Failed to send authenticated packet to peer", peer.String())
continue
}
atomic.AddUint64(&peer.stats.txBytes, length)