aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/config.go30
-rw-r--r--src/constants.go18
-rw-r--r--src/device.go37
-rw-r--r--src/handshake.go241
-rw-r--r--src/helper_test.go4
-rw-r--r--src/index.go2
-rw-r--r--src/keypair.go25
-rw-r--r--src/logger.go23
-rw-r--r--src/macs_test.go6
-rw-r--r--src/main.go2
-rw-r--r--src/misc.go7
-rw-r--r--src/noise_helpers.go2
-rw-r--r--src/noise_protocol.go86
-rw-r--r--src/noise_test.go4
-rw-r--r--src/peer.go54
-rw-r--r--src/send.go206
-rw-r--r--src/tun_linux.go1
17 files changed, 476 insertions, 272 deletions
diff --git a/src/config.go b/src/config.go
index 2f8dc76..8281581 100644
--- a/src/config.go
+++ b/src/config.go
@@ -8,7 +8,6 @@ import (
"net"
"strconv"
"strings"
- "time"
)
// #include <errno.h>
@@ -51,9 +50,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
send("private_key=" + device.privateKey.ToHex())
}
- if device.address != nil {
- send(fmt.Sprintf("listen_port=%d", device.address.Port))
- }
+ send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
for _, peer := range device.peers {
func() {
@@ -106,7 +103,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
key := parts[0]
value := parts[1]
- logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key {
@@ -118,13 +114,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.privateKey = NoisePrivateKey{}
device.mutex.Unlock()
} else {
- device.mutex.Lock()
- err := device.privateKey.FromHex(value)
- device.mutex.Unlock()
+ var sk NoisePrivateKey
+ err := sk.FromHex(value)
if err != nil {
logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
+ device.SetPrivateKey(sk)
}
case "listen_port":
@@ -134,12 +130,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
- device.mutex.Lock()
- if device.address == nil {
- device.address = &net.UDPAddr{}
- }
- device.address.Port = port
- device.mutex.Unlock()
+ device.net.mutex.Lock()
+ device.net.addr.Port = port
+ device.net.conn, err = net.ListenUDP("udp", device.net.addr)
+ device.net.mutex.Unlock()
case "fwmark":
logger.Println("FWMark not handled yet")
@@ -200,13 +194,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "endpoint":
- ip := net.ParseIP(value)
- if ip == nil {
+ addr, err := net.ResolveUDPAddr("udp", value)
+ if err != nil {
logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
- // peer.endpoint = ip FIX
+ peer.endpoint = addr
peer.mutex.Unlock()
case "persistent_keepalive_interval":
@@ -216,7 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
- peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
+ peer.persistentKeepaliveInterval = uint64(secs)
peer.mutex.Unlock()
case "replace_allowed_ips":
diff --git a/src/constants.go b/src/constants.go
index e8cdd63..34217d2 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -5,15 +5,15 @@ import (
)
const (
- RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
- RekeyAfterTime = time.Second * 120
- RekeyAttemptTime = time.Second * 90
- RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
- RejectAfterTime = time.Second * 180
- RejectAfterMessage = (1 << 64) - (1 << 4) - 1
- KeepaliveTimeout = time.Second * 10
- CookieRefreshTime = time.Second * 2
- MaxHandshakeAttempTime = time.Second * 90
+ RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
+ RekeyAfterTime = time.Second * 120
+ RekeyAttemptTime = time.Second * 90
+ RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
+ RejectAfterTime = time.Second * 180
+ RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ KeepaliveTimeout = time.Second * 10
+ CookieRefreshTime = time.Second * 2
+ MaxHandshakeAttemptTime = time.Second * 90
)
const (
diff --git a/src/device.go b/src/device.go
index 52ac6a4..a33e923 100644
--- a/src/device.go
+++ b/src/device.go
@@ -7,16 +7,21 @@ import (
)
type Device struct {
- mtu int
- fwMark uint32
- address *net.UDPAddr // UDP source address
- conn *net.UDPConn // UDP "connection"
+ mtu int
+ log *Logger // collection of loggers for levels
+ idCounter uint // for assigning debug ids to peers
+ fwMark uint32
+ net struct {
+ // seperate for performance reasons
+ mutex sync.RWMutex
+ addr *net.UDPAddr // UDP source address
+ conn *net.UDPConn // UDP "connection"
+ }
mutex sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
routingTable RoutingTable
indices IndexTable
- log *Logger
queue struct {
encryption chan *QueueOutboundElement // parallel work queue
}
@@ -44,17 +49,29 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
}
}
-func NewDevice(tun TUNDevice) *Device {
+func NewDevice(tun TUNDevice, logLevel int) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
- device.log = NewLogger()
+ device.log = NewLogger(logLevel)
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.routingTable.Reset()
+ // listen
+
+ device.net.mutex.Lock()
+ device.net.conn, _ = net.ListenUDP("udp", device.net.addr)
+ addr := device.net.conn.LocalAddr()
+ device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
+ device.net.mutex.Unlock()
+
+ // create queues
+
+ device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
+
// start workers
for i := 0; i < runtime.NumCPU(); i += 1 {
@@ -92,5 +109,11 @@ func (device *Device) RemoveAllPeers() {
peer.mutex.Lock()
delete(device.peers, key)
peer.Close()
+ peer.mutex.Unlock()
}
}
+
+func (device *Device) Close() {
+ device.RemoveAllPeers()
+ close(device.queue.encryption)
+}
diff --git a/src/handshake.go b/src/handshake.go
index 238c339..8f8e2f9 100644
--- a/src/handshake.go
+++ b/src/handshake.go
@@ -24,91 +24,163 @@ func (peer *Peer) SendKeepAlive() bool {
return true
}
-func (peer *Peer) RoutineHandshakeInitiator() {
- var ongoing bool
- var begun time.Time
- var attempts uint
- var timeout time.Timer
-
- device := peer.device
- work := new(QueueOutboundElement)
- buffer := make([]byte, 0, 1024)
-
- queueHandshakeInitiation := func() error {
- work.mutex.Lock()
- defer work.mutex.Unlock()
+func StoppedTimer() *time.Timer {
+ timer := time.NewTimer(time.Hour)
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return timer
+}
- // create initiation
+/* Called when a new authenticated message has been send
+ *
+ * TODO: This might be done in a faster way
+ */
+func (peer *Peer) KeepKeyFreshSending() {
+ send := func() bool {
+ peer.keyPairs.mutex.RLock()
+ defer peer.keyPairs.mutex.RUnlock()
- msg, err := device.CreateMessageInitiation(peer)
- if err != nil {
- return err
+ kp := peer.keyPairs.current
+ if kp == nil {
+ return false
}
- // create "work" element
+ if !kp.isInitiator {
+ return false
+ }
- writer := bytes.NewBuffer(buffer[:0])
- binary.Write(writer, binary.LittleEndian, &msg)
- work.packet = writer.Bytes()
- peer.mac.AddMacs(work.packet)
- peer.InsertOutbound(work)
- return nil
+ nonce := atomic.LoadUint64(&kp.sendNonce)
+ if nonce > RekeyAfterMessages {
+ return true
+ }
+ return time.Now().Sub(kp.created) > RekeyAfterTime
+ }()
+ if send {
+ sendSignal(peer.signal.handshakeBegin)
}
+}
- for {
- select {
- case <-peer.signal.stopInitiator:
- return
-
- case <-peer.signal.newHandshake:
- if ongoing {
- continue
- }
-
- // create handshake
-
- err := queueHandshakeInitiation()
- if err != nil {
- device.log.Error.Println("Failed to create initiation message:", err)
- }
-
- // log when we began
-
- begun = time.Now()
- ongoing = true
- attempts = 0
- timeout.Reset(RekeyTimeout)
-
- case <-peer.timer.sendKeepalive.C:
-
- // active keep-alives
-
- peer.SendKeepAlive()
+/* This is the state machine for handshake initiation
+ *
+ * Associated with this routine is the signal "handshakeBegin"
+ * The routine will read from the "handshakeBegin" channel
+ * at most every RekeyTimeout or with exponential backoff
+ *
+ * Implements exponential backoff for retries
+ */
+func (peer *Peer) RoutineHandshakeInitiator() {
+ work := new(QueueOutboundElement)
+ device := peer.device
+ buffer := make([]byte, 1024)
+ logger := device.log.Debug
+ timeout := time.NewTimer(time.Hour)
- case <-peer.timer.handshakeTimeout.C:
+ logger.Println("Routine, handshake initator, started for peer", peer.id)
- // check if we can stop trying
+ func() {
+ for {
+ var attempts uint
+ var deadline time.Time
- if time.Now().Sub(begun) > MaxHandshakeAttempTime {
- peer.signal.flushNonceQueue <- true
- peer.timer.sendKeepalive.Stop()
- ongoing = false
- continue
+ select {
+ case <-peer.signal.handshakeBegin:
+ case <-peer.signal.stop:
+ return
}
- // otherwise, try again (exponental backoff)
-
- attempts += 1
- err := queueHandshakeInitiation()
- if err != nil {
- device.log.Error.Println("Failed to create initiation message:", err)
+ HandshakeLoop:
+ for run := true; run; {
+ // clear completed signal
+
+ select {
+ case <-peer.signal.handshakeCompleted:
+ case <-peer.signal.stop:
+ return
+ default:
+ }
+
+ // queue handshake
+
+ err := func() error {
+ work.mutex.Lock()
+ defer work.mutex.Unlock()
+
+ // create initiation
+
+ msg, err := device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal
+
+ writer := bytes.NewBuffer(buffer[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ work.packet = writer.Bytes()
+ peer.mac.AddMacs(work.packet)
+ peer.InsertOutbound(work)
+ return nil
+ }()
+ if err != nil {
+ device.log.Error.Println("Failed to create initiation message:", err)
+ break
+ }
+ if attempts == 0 {
+ deadline = time.Now().Add(MaxHandshakeAttemptTime)
+ }
+
+ // set timeout
+
+ if !timeout.Stop() {
+ select {
+ case <-timeout.C:
+ default:
+ }
+ }
+ timeout.Reset((1 << attempts) * RekeyTimeout)
+ attempts += 1
+ device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
+ time.Sleep(RekeyTimeout)
+
+ // wait for handshake or timeout
+
+ select {
+ case <-peer.signal.stop:
+ return
+
+ case <-peer.signal.handshakeCompleted:
+ break HandshakeLoop
+
+ default:
+ select {
+
+ case <-peer.signal.stop:
+ return
+
+ case <-peer.signal.handshakeCompleted:
+ break HandshakeLoop
+
+ case <-timeout.C:
+ nextTimeout := (1 << attempts) * RekeyTimeout
+ if deadline.Before(time.Now().Add(nextTimeout)) {
+ // we do not have time for another attempt
+ peer.signal.flushNonceQueue <- struct{}{}
+ if !peer.timer.sendKeepalive.Stop() {
+ <-peer.timer.sendKeepalive.C
+ }
+ break HandshakeLoop
+ }
+ }
+ }
}
- peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
}
- }
+ }()
+
+ logger.Println("Routine, handshake initator, stopped for peer", peer.id)
}
-/* Handles packets related to handshake
+/* Handles incomming packets related to handshake
*
*
*/
@@ -140,33 +212,12 @@ func (device *Device) HandshakeWorker(queue chan struct {
// check for cookie
case MessageCookieReplyType:
+ if len(elem.msg) != MessageCookieReplySize {
+ continue
+ }
- case MessageTransportType:
- }
-
- }
-}
-
-func (device *Device) KeepKeyFresh(peer *Peer) {
-
- send := func() bool {
- peer.keyPairs.mutex.RLock()
- defer peer.keyPairs.mutex.RUnlock()
-
- kp := peer.keyPairs.current
- if kp == nil {
- return false
- }
-
- nonce := atomic.LoadUint64(&kp.sendNonce)
- if nonce > RekeyAfterMessage {
- return true
+ default:
+ device.log.Error.Println("Invalid message type in handshake queue")
}
-
- return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
- }()
-
- if send {
-
}
}
diff --git a/src/helper_test.go b/src/helper_test.go
index 3a5c331..464292f 100644
--- a/src/helper_test.go
+++ b/src/helper_test.go
@@ -35,7 +35,7 @@ func (tun *DummyTUN) Read(d []byte) (int, error) {
func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN
- dummy.mtu = 1024
+ dummy.mtu = 0
dummy.packets = make(chan []byte, 100)
return &dummy, nil
}
@@ -58,7 +58,7 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
- device := NewDevice(tun)
+ device := NewDevice(tun, LogLevelError)
device.SetPrivateKey(sk)
return device
}
diff --git a/src/index.go b/src/index.go
index 9178510..59e2079 100644
--- a/src/index.go
+++ b/src/index.go
@@ -41,7 +41,7 @@ func (table *IndexTable) Init() {
table.mutex.Unlock()
}
-func (table *IndexTable) ClearIndex(index uint32) {
+func (table *IndexTable) Delete(index uint32) {
if index == 0 {
return
}
diff --git a/src/keypair.go b/src/keypair.go
index 0b029ce..0e845f7 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -13,20 +13,27 @@ type KeyPair struct {
sendNonce uint64
isInitiator bool
created time.Time
+ id uint32
}
type KeyPairs struct {
- mutex sync.RWMutex
- current *KeyPair
- previous *KeyPair
- next *KeyPair // not yet "confirmed by transport"
- newKeyPair chan bool // signals when "current" has been updated
+ mutex sync.RWMutex
+ current *KeyPair
+ previous *KeyPair
+ next *KeyPair // not yet "confirmed by transport"
}
-func (kp *KeyPairs) Init() {
- kp.mutex.Lock()
- kp.newKeyPair = make(chan bool, 5)
- kp.mutex.Unlock()
+/* Called during recieving to confirm the handshake
+ * was completed correctly
+ */
+func (kp *KeyPairs) Used(key *KeyPair) {
+ if key == kp.next {
+ kp.mutex.Lock()
+ kp.previous = kp.current
+ kp.current = key
+ kp.next = nil
+ kp.mutex.Unlock()
+ }
}
func (kp *KeyPairs) Current() *KeyPair {
diff --git a/src/logger.go b/src/logger.go
index 117fe5b..827f9e9 100644
--- a/src/logger.go
+++ b/src/logger.go
@@ -1,6 +1,8 @@
package main
import (
+ "io"
+ "io/ioutil"
"log"
"os"
)
@@ -17,17 +19,30 @@ type Logger struct {
Error *log.Logger
}
-func NewLogger() *Logger {
+func NewLogger(level int) *Logger {
+ output := os.Stdout
logger := new(Logger)
- logger.Debug = log.New(os.Stdout,
+
+ logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
+ if level >= LogLevelDebug {
+ return output, output, output
+ }
+ if level >= LogLevelInfo {
+ return output, output, ioutil.Discard
+ }
+ return output, ioutil.Discard, ioutil.Discard
+ }()
+
+ logger.Debug = log.New(logDebug,
"DEBUG: ",
log.Ldate|log.Ltime|log.Lshortfile,
)
- logger.Info = log.New(os.Stdout,
+
+ logger.Info = log.New(logInfo,
"INFO: ",
log.Ldate|log.Ltime|log.Lshortfile,
)
- logger.Error = log.New(os.Stdout,
+ logger.Error = log.New(logErr,
"ERROR: ",
log.Ldate|log.Ltime|log.Lshortfile,
)
diff --git a/src/macs_test.go b/src/macs_test.go
index fcb64ea..a2a6503 100644
--- a/src/macs_test.go
+++ b/src/macs_test.go
@@ -11,6 +11,9 @@ func TestMAC1(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
+ defer dev1.Close()
+ defer dev2.Close()
+
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@@ -40,6 +43,9 @@ func TestMACs(t *testing.T) {
device2 := randDevice(t)
device2.SetPrivateKey(sk2)
+ defer device1.Close()
+ defer device2.Close()
+
peer1 := device2.NewPeer(device1.privateKey.publicKey())
peer2 := device1.NewPeer(device2.privateKey.publicKey())
diff --git a/src/main.go b/src/main.go
index 9c76ff4..b89af17 100644
--- a/src/main.go
+++ b/src/main.go
@@ -28,7 +28,7 @@ func main() {
return
}
- device := NewDevice(tun)
+ device := NewDevice(tun, LogLevelDebug)
// Start configuration lister
diff --git a/src/misc.go b/src/misc.go
index e1244d6..2bcb148 100644
--- a/src/misc.go
+++ b/src/misc.go
@@ -6,3 +6,10 @@ func min(a uint, b uint) uint {
}
return a
}
+
+func sendSignal(c chan struct{}) {
+ select {
+ case c <- struct{}{}:
+ default:
+ }
+}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index e163ace..1e622a5 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -33,6 +33,7 @@ func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
+ prk = [blake2s.Size]byte{}
return
}
@@ -42,6 +43,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
HMAC(&t2, prk[:], append(t1[:], 0x3))
+ prk = [blake2s.Size]byte{}
return
}
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 46ceeda..a1a1c7b 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -31,8 +31,9 @@ const (
)
const (
- MessageInitiationSize = 148
- MessageResponseSize = 92
+ MessageInitiationSize = 148
+ MessageResponseSize = 92
+ MessageCookieReplySize = 64
)
/* Type is an 8-bit field, followed by 3 nul bytes,
@@ -91,16 +92,11 @@ type Handshake struct {
}
var (
- InitalChainKey [blake2s.Size]byte
- InitalHash [blake2s.Size]byte
- ZeroNonce [chacha20poly1305.NonceSize]byte
+ InitialChainKey [blake2s.Size]byte
+ InitialHash [blake2s.Size]byte
+ ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func init() {
- InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
- InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
-}
-
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
@@ -117,6 +113,13 @@ func (h *Handshake) mixKey(data []byte) {
h.chainKey = mixKey(h.chainKey, data)
}
+/* Do basic precomputations
+ */
+func init() {
+ InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
+ InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
+}
+
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -125,28 +128,30 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// create ephemeral key
var err error
- handshake.chainKey = InitalChainKey
- handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
+ handshake.hash = InitialHash
+ handshake.chainKey = InitialChainKey
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
- device.indices.ClearIndex(handshake.localIndex)
- handshake.localIndex, err = device.indices.NewIndex(peer)
-
// assign index
- var msg MessageInitiation
-
- msg.Type = MessageInitiationType
- msg.Ephemeral = handshake.localEphemeral.publicKey()
+ device.indices.Delete(handshake.localIndex)
+ handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
}
- msg.Sender = handshake.localIndex
+ handshake.mixHash(handshake.remoteStatic[:])
+
+ msg := MessageInitiation{
+ Type: MessageInitiationType,
+ Ephemeral: handshake.localEphemeral.publicKey(),
+ Sender: handshake.localIndex,
+ }
+
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -185,9 +190,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil
}
- hash := mixHash(InitalHash, device.publicKey[:])
+ hash := mixHash(InitialHash, device.publicKey[:])
hash = mixHash(hash, msg.Ephemeral[:])
- chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
+ chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
// decrypt static key
@@ -278,7 +283,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index
var err error
- device.indices.ClearIndex(handshake.localIndex)
+ device.indices.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
@@ -420,10 +425,15 @@ func (peer *Peer) NewKeyPair() *KeyPair {
return nil
}
- // create AEAD instances
+ // zero handshake
+
+ handshake.chainKey = [blake2s.Size]byte{}
+ handshake.localEphemeral = NoisePrivateKey{}
+ peer.handshake.state = HandshakeZeroed
- var keyPair KeyPair
+ // create AEAD instances
+ keyPair := new(KeyPair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
@@ -433,30 +443,32 @@ func (peer *Peer) NewKeyPair() *KeyPair {
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
- keyPair: &keyPair,
+ keyPair: keyPair,
handshake: nil,
})
handshake.localIndex = 0
+ // start timer for keypair
+
// rotate key pairs
+ kp := &peer.keyPairs
func() {
- kp := &peer.keyPairs
kp.mutex.Lock()
defer kp.mutex.Unlock()
if isInitiator {
- kp.previous = peer.keyPairs.current
- kp.current = &keyPair
- kp.newKeyPair <- true
+ if kp.previous != nil {
+ kp.previous.send = nil
+ kp.previous.recv = nil
+ peer.device.indices.Delete(kp.previous.id)
+ }
+ kp.previous = kp.current
+ kp.current = keyPair
+ sendSignal(peer.signal.newKeyPair)
} else {
- kp.next = &keyPair
+ kp.next = keyPair
}
}()
- // zero handshake
-
- handshake.chainKey = [blake2s.Size]byte{}
- handshake.localEphemeral = NoisePrivateKey{}
- peer.handshake.state = HandshakeZeroed
- return &keyPair
+ return keyPair
}
diff --git a/src/noise_test.go b/src/noise_test.go
index 02f6bf3..9b50ff3 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -25,10 +25,12 @@ func TestCurveWrappers(t *testing.T) {
}
func TestNoiseHandshake(t *testing.T) {
-
dev1 := randDevice(t)
dev2 := randDevice(t)
+ defer dev1.Close()
+ defer dev2.Close()
+
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
diff --git a/src/peer.go b/src/peer.go
index 21cad9d..e885cee 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -10,26 +10,29 @@ import (
const ()
type Peer struct {
+ id uint
mutex sync.RWMutex
endpoint *net.UDPAddr
- persistentKeepaliveInterval time.Duration // 0 = disabled
+ persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
device *Device
tx_bytes uint64
rx_bytes uint64
time struct {
- lastSend time.Time // last send message
+ lastSend time.Time // last send message
+ lastHandshake time.Time // last completed handshake
}
signal struct {
- newHandshake chan bool
- flushNonceQueue chan bool // empty queued packets
- stopSending chan bool // stop sending pipeline
- stopInitiator chan bool // stop initiator timer
+ newKeyPair chan struct{} // (size 1) : a new key pair was generated
+ handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
+ handshakeCompleted chan struct{} // (size 1) : handshake completed
+ flushNonceQueue chan struct{} // (size 1) : empty queued packets
+ stop chan struct{} // (size 0) : close to stop all goroutines for peer
}
timer struct {
- sendKeepalive time.Timer
- handshakeTimeout time.Timer
+ sendKeepalive *time.Timer
+ handshakeTimeout *time.Timer
}
queue struct {
nonce chan []byte // nonce / pre-handshake queue
@@ -39,25 +42,30 @@ type Peer struct {
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
- var peer Peer
-
// create peer
+ peer := new(Peer)
peer.mutex.Lock()
+ defer peer.mutex.Unlock()
peer.device = device
- peer.keyPairs.Init()
peer.mac.Init(pk)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
+ peer.timer.sendKeepalive = StoppedTimer()
- // map public key
+ // assign id for debugging
device.mutex.Lock()
+ peer.id = device.idCounter
+ device.idCounter += 1
+
+ // map public key
+
_, ok := device.peers[pk]
if ok {
panic(errors.New("bug: adding existing peer"))
}
- device.peers[pk] = &peer
+ device.peers[pk] = peer
device.mutex.Unlock()
// precompute DH
@@ -67,22 +75,24 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
- peer.mutex.Unlock()
- // start workers
+ // prepare signaling
+
+ peer.signal.stop = make(chan struct{})
+ peer.signal.newKeyPair = make(chan struct{}, 1)
+ peer.signal.handshakeBegin = make(chan struct{}, 1)
+ peer.signal.handshakeCompleted = make(chan struct{}, 1)
+ peer.signal.flushNonceQueue = make(chan struct{}, 1)
- peer.signal.stopSending = make(chan bool, 1)
- peer.signal.stopInitiator = make(chan bool, 1)
- peer.signal.newHandshake = make(chan bool, 1)
- peer.signal.flushNonceQueue = make(chan bool, 1)
+ // outbound pipeline
go peer.RoutineNonce()
go peer.RoutineHandshakeInitiator()
+ go peer.RoutineSequentialSender()
- return &peer
+ return peer
}
func (peer *Peer) Close() {
- peer.signal.stopSending <- true
- peer.signal.stopInitiator <- true
+ close(peer.signal.stop)
}
diff --git a/src/send.go b/src/send.go
index ab75750..d4f9342 100644
--- a/src/send.go
+++ b/src/send.go
@@ -5,6 +5,8 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
+ "sync/atomic"
+ "time"
)
/* Handles outbound flow
@@ -29,6 +31,7 @@ type QueueOutboundElement struct {
packet []byte
nonce uint64
keyPair *KeyPair
+ peer *Peer
}
func (peer *Peer) FlushNonceQueue() {
@@ -46,6 +49,7 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for {
select {
case peer.queue.outbound <- elem:
+ return
default:
select {
case <-peer.queue.outbound:
@@ -61,11 +65,15 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+ if tun.MTU() == 0 {
+ // Dummy
+ return
+ }
+
device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
- device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
@@ -94,13 +102,16 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
- return
}
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
continue
}
+ if peer.endpoint == nil {
+ device.log.Debug.Println("No known endpoint for peer", peer.id)
+ continue
+ }
// insert into nonce/pre-handshake queue
@@ -131,69 +142,95 @@ func (peer *Peer) RoutineNonce() {
var packet []byte
var keyPair *KeyPair
- for {
+ device := peer.device
+ logger := device.log.Debug
- // wait for packet
+ logger.Println("Routine, nonce worker, started for peer", peer.id)
- if packet == nil {
- select {
- case packet = <-peer.queue.nonce:
- case <-peer.signal.stopSending:
- close(peer.queue.outbound)
- return
+ func() {
+
+ for {
+ NextPacket:
+
+ // wait for packet
+
+ if packet == nil {
+ select {
+ case packet = <-peer.queue.nonce:
+ case <-peer.signal.stop:
+ return
+ }
}
- }
- // wait for key pair
+ // wait for key pair
+
+ for {
+ select {
+ case <-peer.signal.newKeyPair:
+ default:
+ }
- for keyPair == nil {
- peer.signal.newHandshake <- true
- select {
- case <-peer.keyPairs.newKeyPair:
keyPair = peer.keyPairs.Current()
- continue
- case <-peer.signal.flushNonceQueue:
- peer.FlushNonceQueue()
- packet = nil
- continue
- case <-peer.signal.stopSending:
- close(peer.queue.outbound)
- return
- }
- }
+ if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
+ if time.Now().Sub(keyPair.created) < RejectAfterTime {
+ break
+ }
+ }
- // process current packet
+ sendSignal(peer.signal.handshakeBegin)
+ logger.Println("Waiting for key-pair, peer", peer.id)
- if packet != nil {
+ select {
+ case <-peer.signal.newKeyPair:
+ logger.Println("Key-pair negotiated for peer", peer.id)
+ goto NextPacket
+
+ case <-peer.signal.flushNonceQueue:
+ logger.Println("Clearing queue for peer", peer.id)
+ peer.FlushNonceQueue()
+ packet = nil
+ goto NextPacket
+
+ case <-peer.signal.stop:
+ return
+ }
+ }
- // create work element
+ // process current packet
- work := new(QueueOutboundElement) // TODO: profile, maybe use pool
- work.keyPair = keyPair
- work.packet = packet
- work.nonce = keyPair.sendNonce
- work.mutex.Lock()
+ if packet != nil {
- packet = nil
- keyPair.sendNonce += 1
+ // create work element
- // drop packets until there is space
+ work := new(QueueOutboundElement) // TODO: profile, maybe use pool
+ work.keyPair = keyPair
+ work.packet = packet
+ work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1)
+ work.peer = peer
+ work.mutex.Lock()
- func() {
- for {
- select {
- case peer.device.queue.encryption <- work:
- return
- default:
- drop := <-peer.device.queue.encryption
- drop.packet = nil
- drop.mutex.Unlock()
+ packet = nil
+
+ // drop packets until there is space
+
+ func() {
+ for {
+ select {
+ case peer.device.queue.encryption <- work:
+ return
+ default:
+ drop := <-peer.device.queue.encryption
+ drop.packet = nil
+ drop.mutex.Unlock()
+ }
}
- }
- }()
- peer.queue.outbound <- work
+ }()
+ peer.queue.outbound <- work
+ }
}
- }
+ }()
+
+ logger.Println("Routine, nonce worker, stopped for peer", peer.id)
}
/* Encrypts the elements in the queue
@@ -227,6 +264,10 @@ func (device *Device) RoutineEncryption() {
nil,
)
work.mutex.Unlock()
+
+ // initiate new handshake
+
+ work.peer.KeepKeyFreshSending()
}
}
@@ -235,21 +276,54 @@ func (device *Device) RoutineEncryption() {
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
-func (peer *Peer) RoutineSequential() {
- for work := range peer.queue.outbound {
- work.mutex.Lock()
- func() {
- peer.mutex.RLock()
- defer peer.mutex.RUnlock()
- if work.packet == nil {
- return
- }
- if peer.endpoint == nil {
- return
- }
- peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
- peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
- }()
- work.mutex.Unlock()
+func (peer *Peer) RoutineSequentialSender() {
+ logger := peer.device.log.Debug
+ logger.Println("Routine, sequential sender, started for peer", peer.id)
+
+ device := peer.device
+
+ for {
+ select {
+ case <-peer.signal.stop:
+ logger.Println("Routine, sequential sender, stopped for peer", peer.id)
+ return
+ case work := <-peer.queue.outbound:
+ work.mutex.Lock()
+ func() {
+ if work.packet == nil {
+ return
+ }
+
+ peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+
+ if peer.endpoint == nil {
+ logger.Println("No endpoint for peer:", peer.id)
+ return
+ }
+
+ device.net.mutex.RLock()
+ defer device.net.mutex.RUnlock()
+
+ if device.net.conn == nil {
+ logger.Println("No source for device")
+ return
+ }
+
+ logger.Println("Sending packet for peer", peer.id, work.packet)
+
+ _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
+ logger.Println("SEND:", peer.endpoint, err)
+ atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
+
+ // shift keep-alive timer
+
+ if peer.persistentKeepaliveInterval != 0 {
+ interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second
+ peer.timer.sendKeepalive.Reset(interval)
+ }
+ }()
+ work.mutex.Unlock()
+ }
}
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index cbbcb70..db13fb0 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -74,5 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
return &NativeTun{
fd: fd,
name: newName,
+ mtu: 1500, // TODO: FIX
}, nil
}