aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-06 15:43:55 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-06 15:43:55 +0200
commit59f9316f51ce3cb470200b0cfe847116a0583d25 (patch)
treee9cfb69aa8b58d8b009167730713c4bde67d7cd4
parentFixed transport header problem (diff)
downloadwireguard-go-59f9316f51ce3cb470200b0cfe847116a0583d25.tar.xz
wireguard-go-59f9316f51ce3cb470200b0cfe847116a0583d25.zip
Initial working full exchange
The implementation is now capable of connecting to another wireguard instance, complete a handshake and exchange transport messages.
-rw-r--r--src/constants.go3
-rw-r--r--src/device.go2
-rw-r--r--src/handshake.go29
-rw-r--r--src/keypair.go3
-rw-r--r--src/noise_protocol.go8
-rw-r--r--src/peer.go6
-rw-r--r--src/receive.go150
-rw-r--r--src/send.go188
8 files changed, 186 insertions, 203 deletions
diff --git a/src/constants.go b/src/constants.go
index 2e484f3..053ba4f 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -21,5 +21,6 @@ const (
QueueInboundSize = 1024
QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8
- MinMessageSize = MessageTransportSize
+ MinMessageSize = MessageTransportSize // keep-alive
+ MaxMessageSize = 4096
)
diff --git a/src/device.go b/src/device.go
index ff10e32..a317122 100644
--- a/src/device.go
+++ b/src/device.go
@@ -80,6 +80,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
+ device.queue.inbound = make(chan []byte, QueueInboundSize)
// prepare signals
@@ -94,6 +95,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming()
+ go device.RoutineWriteToTUN(tun)
return device
}
diff --git a/src/handshake.go b/src/handshake.go
index cf73e9b..88bb8cb 100644
--- a/src/handshake.go
+++ b/src/handshake.go
@@ -12,9 +12,11 @@ import (
* Used by initiator of handshake and with active keep-alive
*/
func (peer *Peer) SendKeepAlive() bool {
+ elem := peer.device.NewOutboundElement()
+ elem.packet = nil
if len(peer.queue.nonce) == 0 {
select {
- case peer.queue.nonce <- []byte{}:
+ case peer.queue.nonce <- elem:
return true
default:
return false
@@ -60,11 +62,10 @@ func (peer *Peer) KeepKeyFreshSending() {
*/
func (peer *Peer) RoutineHandshakeInitiator() {
device := peer.device
- buffer := make([]byte, 1024)
logger := device.log.Debug
timeout := stoppedTimer()
- var work *QueueOutboundElement
+ var elem *QueueOutboundElement
logger.Println("Routine, handshake initator, started for peer", peer.id)
@@ -94,25 +95,25 @@ func (peer *Peer) RoutineHandshakeInitiator() {
// create initiation
- if work != nil {
- work.mutex.Lock()
- work.packet = nil
- work.mutex.Unlock()
+ if elem != nil {
+ elem.Drop()
}
- work = new(QueueOutboundElement)
+ elem = device.NewOutboundElement()
+
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
break
}
- // schedule for sending
+ // marshal & schedule for sending
- writer := bytes.NewBuffer(buffer[:0])
+ writer := bytes.NewBuffer(elem.data[:0])
binary.Write(writer, binary.LittleEndian, msg)
- work.packet = writer.Bytes()
- peer.mac.AddMacs(work.packet)
- peer.InsertOutbound(work)
+ elem.packet = writer.Bytes()
+ peer.mac.AddMacs(elem.packet)
+ println(elem)
+ addToOutboundQueue(peer.queue.outbound, elem)
if attempts == 0 {
deadline = time.Now().Add(MaxHandshakeAttemptTime)
@@ -132,9 +133,11 @@ func (peer *Peer) RoutineHandshakeInitiator() {
return
case <-peer.signal.handshakeCompleted:
+ device.log.Debug.Println("Handshake complete")
break HandshakeLoop
case <-timeout.C:
+ device.log.Debug.Println("Timeout")
if deadline.Before(time.Now().Add(RekeyTimeout)) {
peer.signal.flushNonceQueue <- struct{}{}
if !peer.timer.sendKeepalive.Stop() {
diff --git a/src/keypair.go b/src/keypair.go
index 0fac5cb..3caa0c8 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -7,8 +7,7 @@ import (
)
type KeyPair struct {
- recv cipher.AEAD
- recvNonce uint64
+ receive cipher.AEAD
send cipher.AEAD
sendNonce uint64
isInitiator bool
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 5a62901..0258288 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -446,10 +446,10 @@ func (peer *Peer) NewKeyPair() *KeyPair {
keyPair := new(KeyPair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
- keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
+ keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
- keyPair.recvNonce = 0
keyPair.created = time.Now()
+ keyPair.isInitiator = isInitiator
keyPair.localIndex = peer.handshake.localIndex
keyPair.remoteIndex = peer.handshake.remoteIndex
@@ -462,7 +462,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
})
handshake.localIndex = 0
- // start timer for keypair
+ // TODO: start timer for keypair (clearing)
// rotate key pairs
@@ -473,7 +473,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
if isInitiator {
if kp.previous != nil {
kp.previous.send = nil
- kp.previous.recv = nil
+ kp.previous.receive = nil
peer.device.indices.Delete(kp.previous.localIndex)
}
kp.previous = kp.current
diff --git a/src/peer.go b/src/peer.go
index 1c40598..77bd4df 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -35,7 +35,7 @@ type Peer struct {
handshakeTimeout *time.Timer
}
queue struct {
- nonce chan []byte // nonce / pre-handshake queue
+ nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
}
@@ -78,9 +78,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
// prepare queuing
- peer.queue.nonce = make(chan []byte, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+ peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signaling
diff --git a/src/receive.go b/src/receive.go
index 5afbf7f..50789a1 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -31,17 +31,39 @@ type QueueInboundElement struct {
func (elem *QueueInboundElement) Drop() {
atomic.StoreUint32(&elem.state, ElementStateDropped)
- elem.mutex.Unlock()
+}
+
+func (elem *QueueInboundElement) IsDropped() bool {
+ return atomic.LoadUint32(&elem.state) == ElementStateDropped
+}
+
+func addToInboundQueue(
+ queue chan *QueueInboundElement,
+ element *QueueInboundElement,
+) {
+ for {
+ select {
+ case queue <- element:
+ return
+ default:
+ select {
+ case old := <-queue:
+ old.Drop()
+ default:
+ }
+ }
+ }
}
func (device *Device) RoutineReceiveIncomming() {
- var packet []byte
debugLog := device.log.Debug
debugLog.Println("Routine, receive incomming, started")
errorLog := device.log.Error
+ var buffer []byte // unsliced buffer
+
for {
// check if stopped
@@ -54,28 +76,28 @@ func (device *Device) RoutineReceiveIncomming() {
// read next datagram
- if packet == nil {
- packet = make([]byte, 1<<16)
+ if buffer == nil {
+ buffer = make([]byte, MaxMessageSize)
}
device.net.mutex.RLock()
conn := device.net.conn
device.net.mutex.RUnlock()
+ if conn == nil {
+ time.Sleep(time.Second)
+ continue
+ }
conn.SetReadDeadline(time.Now().Add(time.Second))
- size, raddr, err := conn.ReadFromUDP(packet)
- if err != nil {
- continue
- }
- if size < MinMessageSize {
+ size, raddr, err := conn.ReadFromUDP(buffer)
+ if err != nil || size < MinMessageSize {
continue
}
// handle packet
- packet = packet[:size]
- debugLog.Println("GOT:", packet)
+ packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
func() {
@@ -112,6 +134,7 @@ func (device *Device) RoutineReceiveIncomming() {
// add to handshake queue
+ buffer = nil
device.queue.handshake <- QueueHandshakeElement{
msgType: msgType,
packet: packet,
@@ -137,8 +160,6 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageTransportType:
- debugLog.Println("DEBUG: Got transport")
-
// lookup key pair
if len(packet) < MessageTransportSize {
@@ -169,42 +190,15 @@ func (device *Device) RoutineReceiveIncomming() {
work.state = ElementStateOkay
work.mutex.Lock()
- // add to parallel decryption queue
-
- func() {
- for {
- select {
- case device.queue.decryption <- work:
- return
- default:
- select {
- case elem := <-device.queue.decryption:
- elem.Drop()
- default:
- }
- }
- }
- }()
-
- // add to sequential inbound queue
-
- func() {
- for {
- select {
- case peer.queue.inbound <- work:
- break
- default:
- select {
- case elem := <-peer.queue.inbound:
- elem.Drop()
- default:
- }
- }
- }
- }()
+ // add to decryption queues
+
+ addToInboundQueue(device.queue.decryption, work)
+ addToInboundQueue(peer.queue.inbound, work)
+ buffer = nil
default:
// unknown message type
+ debugLog.Println("Got unknown message from:", raddr)
}
}()
}
@@ -214,6 +208,9 @@ func (device *Device) RoutineDecryption() {
var elem *QueueInboundElement
var nonce [chacha20poly1305.NonceSize]byte
+ logDebug := device.log.Debug
+ logDebug.Println("Routine, decryption, started for device")
+
for {
select {
case elem = <-device.queue.decryption:
@@ -223,31 +220,25 @@ func (device *Device) RoutineDecryption() {
// check if dropped
- state := atomic.LoadUint32(&elem.state)
- if state != ElementStateOkay {
+ if elem.IsDropped() {
+ elem.mutex.Unlock()
continue
}
// split message into fields
- counter := binary.LittleEndian.Uint64(
- elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent],
- )
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt with key-pair
var err error
- binary.LittleEndian.PutUint64(nonce[4:], counter)
- elem.packet, err = elem.keyPair.recv.Open(elem.packet[:0], nonce[:], content, nil)
+ copy(nonce[4:], counter)
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ elem.packet, err = elem.keyPair.receive.Open(elem.packet[:0], nonce[:], content, nil)
if err != nil {
elem.Drop()
- continue
}
-
- // release to consumer
-
- elem.counter = counter
elem.mutex.Unlock()
}
}
@@ -261,6 +252,7 @@ func (device *Device) RoutineHandshake() {
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
+ logDebug.Println("Routine, handshake routine, started for device")
var elem QueueHandshakeElement
@@ -332,13 +324,15 @@ func (device *Device) RoutineHandshake() {
}
sendSignal(peer.signal.handshakeCompleted)
logDebug.Println("Recieved valid response message for peer", peer.id)
- peer.NewKeyPair()
+ kp := peer.NewKeyPair()
+ if kp == nil {
+ logDebug.Println("Failed to derieve key-pair")
+ }
peer.SendKeepAlive()
default:
device.log.Error.Println("Invalid message type in handshake queue")
}
-
}()
}
}
@@ -348,7 +342,6 @@ func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logDebug := device.log.Debug
-
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
for {
@@ -359,20 +352,15 @@ func (peer *Peer) RoutineSequentialReceiver() {
return
case elem = <-peer.queue.inbound:
}
- elem.mutex.Lock()
- // check if dropped
-
- logDebug.Println("MESSSAGE:", elem)
-
- state := atomic.LoadUint32(&elem.state)
- if state != ElementStateOkay {
+ elem.mutex.Lock()
+ if elem.IsDropped() {
continue
}
// check for replay
- // strip padding
+ // update timers
// check for keep-alive
@@ -380,26 +368,30 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
+ // strip padding
+
// insert into inbound TUN queue
device.queue.inbound <- elem.packet
- }
+ // update key material
+ }
}
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
- for {
- var packet []byte
+ logError := device.log.Error
+ logDebug := device.log.Debug
+ logDebug.Println("Routine, sequential tun writer, started")
+ for {
select {
case <-device.signal.stop:
- case packet = <-device.queue.inbound:
- }
-
- size, err := tun.Write(packet)
- device.log.Debug.Println("DEBUG:", size, err)
- if err != nil {
-
+ return
+ case packet := <-device.queue.inbound:
+ _, err := tun.Write(packet)
+ if err != nil {
+ logError.Println("Failed to write packet to TUN device:", err)
+ }
}
}
}
diff --git a/src/send.go b/src/send.go
index 3fe4733..4053669 100644
--- a/src/send.go
+++ b/src/send.go
@@ -25,14 +25,19 @@ import (
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work on the packet.
+ *
+ * If the element is inserted into the "encryption queue",
+ * the content is preceeded by enough "junk" to contain the header
+ * (to allow the constuction of transport messages in-place)
*/
type QueueOutboundElement struct {
state uint32
mutex sync.Mutex
- packet []byte
- nonce uint64
- keyPair *KeyPair
- peer *Peer
+ data [MaxMessageSize]byte
+ packet []byte // slice of packet (sending)
+ nonce uint64 // nonce for encryption
+ keyPair *KeyPair // key-pair for encryption
+ peer *Peer // related peer
}
func (peer *Peer) FlushNonceQueue() {
@@ -46,75 +51,87 @@ func (peer *Peer) FlushNonceQueue() {
}
}
-func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
+func (device *Device) NewOutboundElement() *QueueOutboundElement {
+ elem := new(QueueOutboundElement) // TODO: profile, consider sync.Pool
+ return elem
+}
+
+func (elem *QueueOutboundElement) Drop() {
+ atomic.StoreUint32(&elem.state, ElementStateDropped)
+}
+
+func (elem *QueueOutboundElement) IsDropped() bool {
+ return atomic.LoadUint32(&elem.state) == ElementStateDropped
+}
+
+func addToOutboundQueue(
+ queue chan *QueueOutboundElement,
+ element *QueueOutboundElement,
+) {
for {
select {
- case peer.queue.outbound <- elem:
+ case queue <- element:
return
default:
select {
- case <-peer.queue.outbound:
+ case old := <-queue:
+ old.Drop()
default:
}
}
}
}
-func (elem *QueueOutboundElement) Drop() {
- atomic.StoreUint32(&elem.state, ElementStateDropped)
-}
-
-func (elem *QueueOutboundElement) IsDropped() bool {
- return atomic.LoadUint32(&elem.state) == ElementStateDropped
-}
-
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
- if tun.MTU() == 0 {
- // Dummy
+ if tun == nil {
+ // dummy
return
}
+ elem := device.NewOutboundElement()
+
device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
- packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
- size, err := tun.Read(packet)
+ if elem == nil {
+ elem = device.NewOutboundElement()
+ }
+
+ elem.packet = elem.data[MessageTransportHeaderSize:]
+ size, err := tun.Read(elem.packet)
if err != nil {
device.log.Error.Println("Failed to read packet from TUN device:", err)
continue
}
- packet = packet[:size]
- if len(packet) < IPv4headerSize {
- device.log.Error.Println("Packet too short, length:", len(packet))
+ elem.packet = elem.packet[:size]
+ if len(elem.packet) < IPv4headerSize {
+ device.log.Error.Println("Packet too short, length:", size)
continue
}
// lookup peer
var peer *Peer
- switch packet[0] >> 4 {
+ switch elem.packet[0] >> 4 {
case IPv4version:
- dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
- device.log.Debug.Println("New IPv4 packet:", packet, dst)
case IPv6version:
- dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
- device.log.Debug.Println("New IPv6 packet:", packet, dst)
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
}
if peer == nil {
- device.log.Debug.Println("No peer configured for IP")
continue
}
if peer.endpoint == nil {
@@ -124,18 +141,9 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
// insert into nonce/pre-handshake queue
- for {
- select {
- case peer.queue.nonce <- packet:
- default:
- select {
- case <-peer.queue.nonce:
- default:
- }
- continue
- }
- break
- }
+ addToOutboundQueue(peer.queue.nonce, elem)
+ elem = nil
+
}
}
@@ -148,8 +156,8 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
- var packet []byte
var keyPair *KeyPair
+ var elem *QueueOutboundElement
device := peer.device
logger := device.log.Debug
@@ -163,9 +171,9 @@ func (peer *Peer) RoutineNonce() {
// wait for packet
- if packet == nil {
+ if elem == nil {
select {
- case packet = <-peer.queue.nonce:
+ case elem = <-peer.queue.nonce:
case <-peer.signal.stop:
return
}
@@ -198,7 +206,7 @@ func (peer *Peer) RoutineNonce() {
case <-peer.signal.flushNonceQueue:
logger.Println("Clearing queue for peer", peer.id)
peer.FlushNonceQueue()
- packet = nil
+ elem = nil
goto NextPacket
case <-peer.signal.stop:
@@ -208,36 +216,20 @@ func (peer *Peer) RoutineNonce() {
// process current packet
- if packet != nil {
+ if elem != nil {
// create work element
- work := new(QueueOutboundElement) // TODO: profile, maybe use pool
- work.keyPair = keyPair
- work.packet = packet
- work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
- work.peer = peer
- work.mutex.Lock()
-
- packet = nil
-
- // drop packets until there is space
-
- func() {
- for {
- select {
- case peer.device.queue.encryption <- work:
- return
- default:
- select {
- case elem := <-peer.device.queue.encryption:
- elem.Drop()
- default:
- }
- }
- }
- }()
- peer.queue.outbound <- work
+ elem.keyPair = keyPair
+ elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+ elem.peer = peer
+ elem.mutex.Lock()
+
+ // add to parallel processing and sequential consuming queue
+
+ addToOutboundQueue(device.queue.encryption, elem)
+ addToOutboundQueue(peer.queue.outbound, elem)
+ elem = nil
}
}
}()
@@ -257,42 +249,38 @@ func (device *Device) RoutineEncryption() {
continue
}
- // pad packet
-
- padding := device.mtu - len(work.packet) - MessageTransportSize
- if padding < 0 {
- work.Drop()
- continue
- }
-
- for n := 0; n < padding; n += 1 {
- work.packet = append(work.packet, 0)
- }
- content := work.packet[MessageTransportHeaderSize:]
- copy(content, work.packet)
+ // populate header fields
- // prepare header
+ func() {
+ header := work.data[:MessageTransportHeaderSize]
- binary.LittleEndian.PutUint32(work.packet[:4], MessageTransportType)
- binary.LittleEndian.PutUint32(work.packet[4:8], work.keyPair.remoteIndex)
- binary.LittleEndian.PutUint64(work.packet[8:16], work.nonce)
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- device.log.Debug.Println(work.packet, work.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
+ }()
// encrypt content
- binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
- work.keyPair.send.Seal(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- work.mutex.Unlock()
+ func() {
+ binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
+ work.packet = work.keyPair.send.Seal(
+ work.packet[:0],
+ nonce[:],
+ work.packet,
+ nil,
+ )
+ work.mutex.Unlock()
+ }()
- device.log.Debug.Println(work.packet, work.nonce)
+ // reslice to include header
- // initiate new handshake
+ work.packet = work.data[:MessageTransportHeaderSize+len(work.packet)]
+
+ // refresh key if necessary
work.peer.KeepKeyFreshSending()
}
@@ -340,8 +328,6 @@ func (peer *Peer) RoutineSequentialSender() {
return
}
- logger.Println(work.packet)
-
_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
if err != nil {
return