aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-04 16:15:53 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-04 16:15:53 +0200
commit8c34c4cbb3780c433148966a004f5a51aace0f64 (patch)
treea590de76c326f6dfe3c92d2e27b78ce2ab792289 /src
parentMerge branch 'master' of git.zx2c4.com:wireguard-go (diff)
downloadwireguard-go-8c34c4cbb3780c433148966a004f5a51aace0f64.tar.xz
wireguard-go-8c34c4cbb3780c433148966a004f5a51aace0f64.zip
First set of code review patches
Diffstat (limited to '')
-rw-r--r--src/config.go225
-rw-r--r--src/constants.go3
-rw-r--r--src/device.go44
-rw-r--r--src/index.go10
-rw-r--r--src/macs.go15
-rw-r--r--src/noise_helpers.go8
-rw-r--r--src/noise_protocol.go9
-rw-r--r--src/noise_types.go22
-rw-r--r--src/receive.go44
-rw-r--r--src/send.go51
-rw-r--r--src/timers.go33
-rw-r--r--src/trie.go9
-rw-r--r--src/tun.go1
-rw-r--r--src/tun_linux.go6
-rw-r--r--src/uapi_linux.go13
15 files changed, 313 insertions, 180 deletions
diff --git a/src/config.go b/src/config.go
index 72a604f..e2d7f20 100644
--- a/src/config.go
+++ b/src/config.go
@@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("persistent_keepalive_interval=%d",
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
))
+
for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
@@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logDebug := device.log.Debug
var peer *Peer
+
+ deviceConfig := true
+
for scanner.Scan() {
// parse line
@@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
parts := strings.Split(line, "=")
if len(parts) != 2 {
- return &IPCError{Code: ipcErrorNoKeyValue}
+ return &IPCError{Code: ipcErrorProtocol}
}
key := parts[0]
value := parts[1]
- switch key {
+ /* device configuration */
- /* interface configuration */
+ if deviceConfig {
- case "private_key":
- var sk NoisePrivateKey
- if value == "" {
- device.SetPrivateKey(sk)
- } else {
- err := sk.FromHex(value)
- if err != nil {
- logError.Println("Failed to set private_key:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
+ switch key {
+ case "private_key":
+ var sk NoisePrivateKey
+ if value == "" {
+ device.SetPrivateKey(sk)
+ } else {
+ err := sk.FromHex(value)
+ if err != nil {
+ logError.Println("Failed to set private_key:", err)
+ return &IPCError{Code: ipcErrorInvalid}
+ }
+ device.SetPrivateKey(sk)
}
- device.SetPrivateKey(sk)
- }
- case "listen_port":
- port, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
- }
- netc := &device.net
- netc.mutex.Lock()
- if netc.addr.Port != int(port) {
- if netc.conn != nil {
- netc.conn.Close()
+ case "listen_port":
+ port, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ logError.Println("Failed to set listen_port:", err)
+ return &IPCError{Code: ipcErrorInvalid}
}
- netc.addr.Port = int(port)
- netc.conn, err = net.ListenUDP("udp", netc.addr)
- }
- netc.mutex.Unlock()
- if err != nil {
- logError.Println("Failed to create UDP listener:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
- }
+ netc := &device.net
+ netc.mutex.Lock()
+ if netc.addr.Port != int(port) {
+ if netc.conn != nil {
+ netc.conn.Close()
+ }
+ netc.addr.Port = int(port)
+ netc.conn, err = net.ListenUDP("udp", netc.addr)
+ }
+ netc.mutex.Unlock()
+ if err != nil {
+ logError.Println("Failed to create UDP listener:", err)
+ return &IPCError{Code: ipcErrorIO}
+ }
+ // TODO: Clear source address of all peers
- case "fwmark":
- logError.Println("FWMark not handled yet")
+ case "fwmark":
+ logError.Println("FWMark not handled yet")
+ // TODO: Clear source address of all peers
- case "public_key":
- var pubKey NoisePublicKey
- err := pubKey.FromHex(value)
- if err != nil {
- logError.Println("Failed to get peer by public_key:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
- }
- device.mutex.RLock()
- peer, _ = device.peers[pubKey]
- device.mutex.RUnlock()
- if peer == nil {
- peer = device.NewPeer(pubKey)
- }
+ case "public_key":
- case "replace_peers":
- if value == "true" {
- device.RemoveAllPeers()
- } else {
- logError.Println("Failed to set replace_peers, invalid value:", value)
- return &IPCError{Code: ipcErrorInvalidValue}
- }
+ // switch to peer configuration
- default:
+ deviceConfig = false
- /* peer configuration */
+ case "replace_peers":
+ if value != "true" {
+ logError.Println("Failed to set replace_peers, invalid value:", value)
+ return &IPCError{Code: ipcErrorInvalid}
+ }
+ device.RemoveAllPeers()
- if peer == nil {
- logError.Println("No peer referenced, before peer operation")
- return &IPCError{Code: ipcErrorNoPeer}
+ default:
+ logError.Println("Invalid UAPI key (device configuration):", key)
+ return &IPCError{Code: ipcErrorInvalid}
}
+ }
+
+ /* peer configuration */
+
+ if !deviceConfig {
switch key {
+ case "public_key":
+ var pubKey NoisePublicKey
+ err := pubKey.FromHex(value)
+ if err != nil {
+ logError.Println("Failed to get peer by public_key:", err)
+ return &IPCError{Code: ipcErrorInvalid}
+ }
+
+ // check if public key of peer equal to device
+
+ device.mutex.RLock()
+ if device.publicKey.Equals(pubKey) {
+ device.mutex.RUnlock()
+ logError.Println("Public key of peer matches private key of device")
+ return &IPCError{Code: ipcErrorInvalid}
+ }
+
+ // find peer referenced
+
+ peer, _ = device.peers[pubKey]
+ device.mutex.RUnlock()
+ if peer == nil {
+ peer = device.NewPeer(pubKey)
+ }
+
case "remove":
+ if value != "true" {
+ logError.Println("Failed to set remove, invalid value:", value)
+ return &IPCError{Code: ipcErrorInvalid}
+ }
device.RemovePeer(peer.handshake.remoteStatic)
logDebug.Println("Removing", peer.String())
peer = nil
@@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}()
if err != nil {
logError.Println("Failed to set preshared_key:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
+ return &IPCError{Code: ipcErrorInvalid}
}
case "endpoint":
+ // TODO: Only IP and port
addr, err := net.ResolveUDPAddr("udp", value)
if err != nil {
logError.Println("Failed to set endpoint:", value)
- return &IPCError{Code: ipcErrorInvalidValue}
+ return &IPCError{Code: ipcErrorInvalid}
}
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
case "persistent_keepalive_interval":
- secs, err := strconv.ParseInt(value, 10, 64)
- if secs < 0 || err != nil {
+
+ // update keep-alive interval
+
+ secs, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
+ return &IPCError{Code: ipcErrorInvalid}
}
- atomic.StoreUint64(
+
+ old := atomic.SwapUint64(
&peer.persistentKeepaliveInterval,
- uint64(secs),
+ secs,
)
+ // send immediate keep-alive
+
+ if old == 0 && secs != 0 {
+ up, err := device.tun.IsUp()
+ if err != nil {
+ logError.Println("Failed to get tun device status:", err)
+ return &IPCError{Code: ipcErrorIO}
+ }
+ if up {
+ peer.SendKeepAlive()
+ }
+ }
+
case "replace_allowed_ips":
- if value == "true" {
- device.routingTable.RemovePeer(peer)
- } else {
+ if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
- return &IPCError{Code: ipcErrorInvalidValue}
+ return &IPCError{Code: ipcErrorInvalid}
}
+ device.routingTable.RemovePeer(peer)
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed_ip:", err)
- return &IPCError{Code: ipcErrorInvalidValue}
+ return &IPCError{Code: ipcErrorInvalid}
}
ones, _ := network.Mask.Size()
device.routingTable.Insert(network.IP, uint(ones), peer)
default:
- logError.Println("Invalid UAPI key:", key)
- return &IPCError{Code: ipcErrorInvalidKey}
+ logError.Println("Invalid UAPI key (peer configuration):", key)
+ return &IPCError{Code: ipcErrorInvalid}
}
}
}
@@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcHandle(device *Device, socket net.Conn) {
+ // create buffered read/writer
+
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) {
return
}
- switch op {
+ // handle operation
+ var status *IPCError
+
+ switch op {
case "set=1\n":
device.log.Debug.Println("Config, set operation")
- err := ipcSetOperation(device, buffered)
- if err != nil {
- fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
- }
- return
+ status = ipcSetOperation(device, buffered)
case "get=1\n":
device.log.Debug.Println("Config, get operation")
- err := ipcGetOperation(device, buffered)
- if err != nil {
- fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
- }
- return
+ status = ipcGetOperation(device, buffered)
default:
device.log.Error.Println("Invalid UAPI operation:", op)
+ return
+ }
+
+ // write status
+ if status != nil {
+ device.log.Error.Println(status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
}
}
diff --git a/src/constants.go b/src/constants.go
index 09d33d8..f09ded6 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -16,6 +16,7 @@ const (
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120
MaxHandshakeAttemptTime = time.Second * 90
+ PaddingMultiple = 16
)
const (
@@ -31,5 +32,5 @@ const (
QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize // size of keep-alive
- MaxMessageSize = (1 << 16) - 1
+ MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize
)
diff --git a/src/device.go b/src/device.go
index 1185d60..de96f0b 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,6 +1,8 @@
package main
import (
+ "errors"
+ "fmt"
"net"
"runtime"
"sync"
@@ -10,6 +12,7 @@ import (
type Device struct {
mtu int32
+ tun TUNDevice
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
@@ -43,24 +46,46 @@ type Device struct {
mac MACStateDevice
}
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock()
+ // check if public key is matching any peer
+
+ publicKey := sk.publicKey()
+ for _, peer := range device.peers {
+ h := &peer.handshake
+ h.mutex.RLock()
+ if h.remoteStatic.Equals(publicKey) {
+ h.mutex.RUnlock()
+ return errors.New("Private key matches public key of peer")
+ }
+ h.mutex.RUnlock()
+ }
+
// update key material
device.privateKey = sk
- device.publicKey = sk.publicKey()
- device.mac.Init(device.publicKey)
+ device.publicKey = publicKey
+ device.mac.Init(publicKey)
// do DH precomputations
+ isZero := device.privateKey.IsZero()
+
for _, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
- h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ if isZero {
+ h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+ } else {
+ h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ }
+ fmt.Println(h.precomputedStaticStatic)
h.mutex.Unlock()
}
+
+ return nil
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
@@ -77,6 +102,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.mutex.Lock()
defer device.mutex.Unlock()
+ device.tun = tun
device.log = NewLogger(logLevel)
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
@@ -119,22 +145,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineBusyMonitor()
- go device.RoutineMTUUpdater(tun)
- go device.RoutineWriteToTUN(tun)
- go device.RoutineReadFromTUN(tun)
+ go device.RoutineMTUUpdater()
+ go device.RoutineWriteToTUN()
+ go device.RoutineReadFromTUN()
go device.RoutineReceiveIncomming()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device
}
-func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
+func (device *Device) RoutineMTUUpdater() {
logError := device.log.Error
for ; ; time.Sleep(5 * time.Second) {
// load updated MTU
- mtu, err := tun.MTU()
+ mtu, err := device.tun.MTU()
if err != nil {
logError.Println("Failed to load updated MTU of device:", err)
continue
diff --git a/src/index.go b/src/index.go
index 44b4974..e518b0f 100644
--- a/src/index.go
+++ b/src/index.go
@@ -3,6 +3,7 @@ package main
import (
"crypto/rand"
"sync"
+ "unsafe"
)
/* Index=0 is reserved for unset indecies
@@ -23,14 +24,7 @@ type IndexTable struct {
func randUint32() (uint32, error) {
var buff [4]byte
_, err := rand.Read(buff[:])
- id := uint32(buff[0])
- id <<= 8
- id |= uint32(buff[1])
- id <<= 8
- id |= uint32(buff[2])
- id <<= 8
- id |= uint32(buff[3])
- return id, err
+ return *((*uint32)(unsafe.Pointer(&buff))), err
}
func (table *IndexTable) Init() {
diff --git a/src/macs.go b/src/macs.go
index 841ef31..beb5f76 100644
--- a/src/macs.go
+++ b/src/macs.go
@@ -3,7 +3,6 @@ package main
import (
"crypto/hmac"
"crypto/rand"
- "errors"
"golang.org/x/crypto/blake2s"
"net"
"sync"
@@ -15,14 +14,14 @@ type MACStateDevice struct {
refreshed time.Time
secret [blake2s.Size]byte
keyMAC1 [blake2s.Size]byte
- keyMAC2 [blake2s.Size]byte
+ keyMAC2 [blake2s.Size]byte // TODO: Change to more descriptive size constant, rename to something.
}
type MACStatePeer struct {
mutex sync.RWMutex
cookieSet time.Time
cookie [blake2s.Size128]byte
- lastMAC1 [blake2s.Size128]byte
+ lastMAC1 [blake2s.Size128]byte // TODO: Check if set
keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte
}
@@ -83,7 +82,7 @@ func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool {
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
mac, _ := blake2s.New128(state.secret[:])
mac.Write(addr.IP)
- mac.Write(port[:])
+ mac.Write(port[:]) // TODO: Be faster and more platform dependent?
mac.Sum(cookie[:0])
}()
@@ -130,7 +129,7 @@ func (device *Device) CreateMessageCookieReply(
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
mac, _ := blake2s.New128(state.secret[:])
mac.Write(addr.IP)
- mac.Write(port[:])
+ mac.Write(port[:]) // TODO: Do whatever we did above
mac.Sum(cookie[:0])
}()
@@ -196,6 +195,7 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
if err != nil {
return false
}
+
state.cookieSet = time.Now()
state.cookie = cookie
return true
@@ -229,10 +229,6 @@ func (state *MACStatePeer) Init(pk NoisePublicKey) {
func (state *MACStatePeer) AddMacs(msg []byte) {
size := len(msg)
- if size < blake2s.Size128*2 {
- panic(errors.New("bug: message too short"))
- }
-
startMac1 := size - (blake2s.Size128 * 2)
startMac2 := size - blake2s.Size128
@@ -250,6 +246,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
mac.Sum(mac1[:0])
}()
copy(state.lastMAC1[:], mac1)
+ // TODO: Set lastMac flag
// set mac2
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index 1e622a5..105f78f 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -47,6 +47,14 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
return
}
+func isZero(val []byte) bool {
+ var acc byte
+ for _, b := range val {
+ acc |= b
+ }
+ return acc == 0
+}
+
/* curve25519 wrappers */
func newPrivateKey() (sk NoisePrivateKey, err error) {
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index e2ff573..5c776a8 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -135,6 +135,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errors.New("Static shared secret is zero")
+ }
+
// create ephemeral key
var err error
@@ -226,7 +230,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if peer == nil {
return nil
}
+
handshake := &peer.handshake
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil
+ }
// verify identity
@@ -472,6 +480,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
func() {
kp.mutex.Lock()
defer kp.mutex.Unlock()
+ // TODO: Adapt kernel behavior noise.c:161
if isInitiator {
if kp.previous != nil {
kp.previous.send = nil
diff --git a/src/noise_types.go b/src/noise_types.go
index 5ebc130..1a944df 100644
--- a/src/noise_types.go
+++ b/src/noise_types.go
@@ -1,6 +1,7 @@
package main
import (
+ "crypto/subtle"
"encoding/hex"
"errors"
"golang.org/x/crypto/chacha20poly1305"
@@ -31,12 +32,12 @@ func loadExactHex(dst []byte, src string) error {
}
func (key NoisePrivateKey) IsZero() bool {
- for _, b := range key[:] {
- if b != 0 {
- return false
- }
- }
- return true
+ var zero NoisePrivateKey
+ return key.Equals(zero)
+}
+
+func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
+ return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoisePrivateKey) FromHex(src string) error {
@@ -55,6 +56,15 @@ func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:])
}
+func (key NoisePublicKey) IsZero() bool {
+ var zero NoisePublicKey
+ return key.Equals(zero)
+}
+
+func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
+ return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
func (key *NoiseSymmetricKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
diff --git a/src/receive.go b/src/receive.go
index 700b894..fb5c51f 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -73,6 +73,8 @@ func (device *Device) addToHandshakeQueue(
}
/* Routine determining the busy state of the interface
+ *
+ * TODO: Under load for some time
*/
func (device *Device) RoutineBusyMonitor() {
samples := 0
@@ -131,6 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
buffer = device.GetMessageBuffer()
}
+ // TODO: Take writelock to sleep
device.net.mutex.RLock()
conn := device.net.conn
device.net.mutex.RUnlock()
@@ -139,6 +142,7 @@ func (device *Device) RoutineReceiveIncomming() {
continue
}
+ // TODO: Wait for new conn or message
conn.SetReadDeadline(time.Now().Add(time.Second))
size, raddr, err := conn.ReadFromUDP(buffer[:])
@@ -156,6 +160,8 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageInitiationType, MessageResponseType:
+ // TODO: Check size early
+
// add to handshake queue
device.addToHandshakeQueue(
@@ -171,6 +177,8 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageCookieReplyType:
+ // TODO: Queue all the things
+
// verify and update peer cookie state
if len(packet) != MessageCookieReplySize {
@@ -250,7 +258,7 @@ func (device *Device) RoutineDecryption() {
// check if dropped
if elem.IsDropped() {
- elem.mutex.Unlock()
+ elem.mutex.Unlock() // TODO: Make consistent with send
continue
}
@@ -318,6 +326,7 @@ func (device *Device) RoutineHandshake() {
logError.Println("Failed to create cookie reply:", err)
return
}
+ // TODO: Use temp
writer := bytes.NewBuffer(elem.packet[:0])
binary.Write(writer, binary.LittleEndian, reply)
elem.packet = writer.Bytes()
@@ -330,6 +339,8 @@ func (device *Device) RoutineHandshake() {
// ratelimit
+ // TODO: Only ratelimit when busy
+
if !device.ratelimiter.Allow(elem.source.IP) {
return
}
@@ -364,9 +375,14 @@ func (device *Device) RoutineHandshake() {
)
return
}
- peer.TimerPacketReceived()
+
+ // update timers
+
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint
+ // TODO: Add a race condition \s
peer.mutex.Lock()
peer.endpoint = elem.source
@@ -381,6 +397,7 @@ func (device *Device) RoutineHandshake() {
}
peer.TimerEphemeralKeyCreated()
+ peer.NewKeyPair()
logDebug.Println("Creating response message for", peer.String())
@@ -392,8 +409,7 @@ func (device *Device) RoutineHandshake() {
// send response
peer.SendBuffer(packet)
- peer.TimerPacketSent()
- peer.NewKeyPair()
+ peer.TimerAnyAuthenticatedPacketTraversal()
case MessageResponseType:
@@ -423,8 +439,14 @@ func (device *Device) RoutineHandshake() {
return
}
- peer.TimerPacketReceived()
+ // update timers
+
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
peer.TimerHandshakeComplete()
+
+ // derive key-pair
+
peer.NewKeyPair()
peer.SendKeepAlive()
@@ -467,8 +489,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
return
}
- peer.TimerPacketReceived()
- peer.TimerTransportReceived()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving()
// check if using new key-pair
@@ -504,6 +526,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
+ // TODO: check length of packet & NOT TOO SMALL either
elem.packet = elem.packet[:length]
// verify IPv4 source
@@ -525,6 +548,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
+ // TODO: check length of packet
elem.packet = elem.packet[:length]
// verify IPv6 source
@@ -542,11 +566,13 @@ func (peer *Peer) RoutineSequentialReceiver() {
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
device.addToInboundQueue(device.queue.inbound, elem)
+
+ // TODO: move TUN write into per peer routine
}()
}
}
-func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
+func (device *Device) RoutineWriteToTUN() {
logError := device.log.Error
logDebug := device.log.Debug
@@ -557,7 +583,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
case <-device.signal.stop:
return
case elem := <-device.queue.inbound:
- _, err := tun.Write(elem.packet)
+ _, err := device.tun.Write(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
diff --git a/src/send.go b/src/send.go
index 37078b9..fc35732 100644
--- a/src/send.go
+++ b/src/send.go
@@ -110,17 +110,19 @@ func addToEncryptionQueue(
}
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+ peer.device.net.mutex.RLock()
+ defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+
endpoint := peer.endpoint
- peer.mutex.RUnlock()
+ conn := peer.device.net.conn
+
if endpoint == nil {
return 0, ErrorNoEndpoint
}
- peer.device.net.mutex.RLock()
- conn := peer.device.net.conn
- peer.device.net.mutex.RUnlock()
if conn == nil {
return 0, ErrorNoConnection
}
@@ -133,13 +135,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
*
* Obs. Single instance per TUN device
*/
-func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+func (device *Device) RoutineReadFromTUN() {
- if tun == nil {
+ if device.tun == nil {
return
}
- elem := device.NewOutboundElement()
+ var elem *QueueOutboundElement
logDebug := device.log.Debug
logError := device.log.Error
@@ -153,32 +155,38 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
elem = device.NewOutboundElement()
}
+ // TODO: THIS!
elem.packet = elem.buffer[MessageTransportHeaderSize:]
- size, err := tun.Read(elem.packet)
+ size, err := device.tun.Read(elem.packet)
if err != nil {
-
- // stop process
-
logError.Println("Failed to read packet from TUN device:", err)
device.Close()
return
}
- elem.packet = elem.packet[:size]
- if len(elem.packet) < ipv4.HeaderLen {
- logError.Println("Packet too short, length:", size)
+ if size == 0 {
continue
}
+ println(size, err)
+
+ elem.packet = elem.packet[:size]
+
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case ipv4.Version:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
case ipv6.Version:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
@@ -190,10 +198,15 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
continue
}
+ // check if known endpoint
+
+ peer.mutex.RLock()
if peer.endpoint == nil {
+ peer.mutex.RUnlock()
logDebug.Println("No known endpoint for peer", peer.String())
continue
}
+ peer.mutex.RUnlock()
// insert into nonce/pre-handshake queue
@@ -334,8 +347,12 @@ func (device *Device) RoutineEncryption() {
// pad content to MTU size
mtu := int(atomic.LoadInt32(&device.mtu))
- for i := len(elem.packet); i < mtu; i++ {
- elem.packet = append(elem.packet, 0)
+ pad := len(elem.packet) % PaddingMultiple
+ if pad > 0 {
+ for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
+ elem.packet = append(elem.packet, 0)
+ }
+ // TODO: How good is this code
}
// encrypt content (append to header)
@@ -390,7 +407,7 @@ func (peer *Peer) RoutineSequentialSender() {
// update timers
- peer.TimerPacketSent()
+ peer.TimerAnyAuthenticatedPacketTraversal()
if len(elem.packet) != MessageKeepaliveSize {
peer.TimerDataSent()
}
diff --git a/src/timers.go b/src/timers.go
index 5a16e9b..1be85f0 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -60,10 +60,8 @@ func (peer *Peer) SendKeepAlive() bool {
return true
}
-/* Authenticated data packet send
- * Always called together with peer.EventPacketSend
- *
- * - Start new handshake timer
+/* Event:
+ * Sent non-empty (authenticated) transport message
*/
func (peer *Peer) TimerDataSent() {
timerStop(peer.timer.keepalivePassive)
@@ -75,8 +73,6 @@ func (peer *Peer) TimerDataSent() {
/* Event:
* Received non-empty (authenticated) transport message
- *
- * - Start passive keep-alive timer
*/
func (peer *Peer) TimerDataReceived() {
if peer.timer.pendingKeepalivePassive {
@@ -88,17 +84,16 @@ func (peer *Peer) TimerDataReceived() {
}
/* Event:
- * Any (authenticated) transport message received
- * (keep-alive or data)
+ * Any (authenticated) packet received
*/
-func (peer *Peer) TimerTransportReceived() {
+func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
timerStop(peer.timer.newHandshake)
}
/* Event:
- * Any packet send to the peer.
+ * Any authenticated packet send / received.
*/
-func (peer *Peer) TimerPacketSent() {
+func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
duration := time.Duration(interval) * time.Second
@@ -106,13 +101,6 @@ func (peer *Peer) TimerPacketSent() {
}
}
-/* Event:
- * Any authenticated packet received from peer
- */
-func (peer *Peer) TimerPacketReceived() {
- peer.TimerPacketSent()
-}
-
/* Called after succesfully completing a handshake.
* i.e. after:
*
@@ -129,7 +117,9 @@ func (peer *Peer) TimerHandshakeComplete() {
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
-/* Called whenever an ephemeral key is generated
+/* Event:
+ * An ephemeral key is generated
+ *
* i.e after:
*
* CreateMessageInitiation
@@ -257,7 +247,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
select {
case <-peer.signal.handshakeBegin:
- signalSend(peer.signal.handshakeBegin)
case <-peer.signal.stop:
return
}
@@ -303,7 +292,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
- peer.TimerPacketSent()
_, err = peer.SendBuffer(packet)
if err != nil {
@@ -314,6 +302,8 @@ func (peer *Peer) RoutineHandshakeInitiator() {
continue
}
+ peer.TimerAnyAuthenticatedPacketTraversal()
+
// set timeout
timeout := time.NewTimer(RekeyTimeout)
@@ -337,7 +327,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
continue
}
-
}
// allow new signal to be set
diff --git a/src/trie.go b/src/trie.go
index e81b5b6..aa96a8a 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -32,11 +32,14 @@ type Trie struct {
/* Finds length of matching prefix
* TODO: Make faster
*
- * Assumption: len(ip1) == len(ip2)
+ * Assumption:
+ * len(ip1) == len(ip2)
+ * len(ip1) mod 4 = 0
*/
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1 []byte, ip2 []byte) uint {
var i uint
- size := uint(len(ip1))
+ size := uint(len(ip1)) / 4
+
for i = 0; i < size; i++ {
v := ip1[i] ^ ip2[i]
if v != 0 {
diff --git a/src/tun.go b/src/tun.go
index f529c54..d782bd5 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -9,6 +9,7 @@ const DefaultMTU = 1420
type TUNDevice interface {
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
+ IsUp() (bool, error) // is the interface up?
MTU() (int, error) // returns the MTU of the device
Name() string // returns the current name
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index 261d142..d0e2f47 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -7,6 +7,7 @@ import (
"encoding/binary"
"errors"
"golang.org/x/sys/unix"
+ "net"
"os"
"strings"
"unsafe"
@@ -19,6 +20,11 @@ type NativeTun struct {
name string
}
+func (tun *NativeTun) IsUp() (bool, error) {
+ inter, err := net.InterfaceByName(tun.name)
+ return inter.Flags&net.FlagUp != 0, err
+}
+
func (tun *NativeTun) Name() string {
return tun.name
}
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
index fd83918..d6d78e7 100644
--- a/src/uapi_linux.go
+++ b/src/uapi_linux.go
@@ -11,13 +11,12 @@ import (
)
const (
- ipcErrorIO = int64(unix.EIO)
- ipcErrorNoPeer = int64(unix.EPROTO)
- ipcErrorNoKeyValue = int64(unix.EPROTO)
- ipcErrorInvalidKey = int64(unix.EPROTO)
- ipcErrorInvalidValue = int64(unix.EPROTO)
- socketDirectory = "/var/run/wireguard"
- socketName = "%s.sock"
+ ipcErrorIO = -int64(unix.EIO)
+ ipcErrorNotDefined = -int64(unix.ENODEV)
+ ipcErrorProtocol = -int64(unix.EPROTO)
+ ipcErrorInvalid = -int64(unix.EINVAL)
+ socketDirectory = "/var/run/wireguard"
+ socketName = "%s.sock"
)
/* TODO: