aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-17 16:16:18 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-07-17 16:16:18 +0200
commitc5d7efc2467abb6cd8365c83fae68da6924c17f2 (patch)
tree0324219cf4979a87fc45fc575e26f7058b0a196f /src
parentAdded padding (diff)
downloadwireguard-go-c5d7efc2467abb6cd8365c83fae68da6924c17f2.tar.xz
wireguard-go-c5d7efc2467abb6cd8365c83fae68da6924c17f2.zip
Fixed deadlock in index.go
Diffstat (limited to 'src')
-rw-r--r--src/config.go160
-rw-r--r--src/device.go12
-rw-r--r--src/index.go6
-rw-r--r--src/main.go20
-rw-r--r--src/noise_protocol.go5
-rw-r--r--src/receive.go8
-rw-r--r--src/send.go81
-rw-r--r--src/timers.go52
8 files changed, 193 insertions, 151 deletions
diff --git a/src/config.go b/src/config.go
index 4edaa2e..d92e8d7 100644
--- a/src/config.go
+++ b/src/config.go
@@ -8,39 +8,36 @@ import (
"net"
"strconv"
"strings"
+ "sync/atomic"
+ "syscall"
)
-// #include <errno.h>
-import "C"
-
-/* TODO: More fine grained?
- */
const (
- ipcErrorNoPeer = C.EPROTO
- ipcErrorNoKeyValue = C.EPROTO
- ipcErrorInvalidKey = C.EPROTO
- ipcErrorInvalidValue = C.EPROTO
+ ipcErrorIO = syscall.EIO
+ ipcErrorNoPeer = syscall.EPROTO
+ ipcErrorNoKeyValue = syscall.EPROTO
+ ipcErrorInvalidKey = syscall.EPROTO
+ ipcErrorInvalidValue = syscall.EPROTO
)
type IPCError struct {
- Code int
+ Code syscall.Errno
}
func (s *IPCError) Error() string {
return fmt.Sprintf("IPC error: %d", s.Code)
}
-func (s *IPCError) ErrorCode() int {
- return s.Code
+func (s *IPCError) ErrorCode() uintptr {
+ return uintptr(s.Code)
}
-func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
-
- device.mutex.RLock()
- defer device.mutex.RUnlock()
+func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// create lines
+ device.mutex.RLock()
+
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
@@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
}
send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
- send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+ send(fmt.Sprintf("persistent_keepalive_interval=%d",
+ atomic.LoadUint64(&peer.persistentKeepaliveInterval),
+ ))
for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
}()
}
+ device.mutex.RUnlock()
+
// send lines
for _, line := range lines {
_, err := socket.WriteString(line + "\n")
if err != nil {
- return err
+ return &IPCError{
+ Code: ipcErrorIO,
+ }
}
}
@@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
- logger := device.log.Debug
scanner := bufio.NewScanner(socket)
+ logError := device.log.Error
+ logDebug := device.log.Debug
var peer *Peer
for scanner.Scan() {
- // Parse line
+ // parse line
line := scanner.Text()
if line == "" {
@@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
parts := strings.Split(line, "=")
if len(parts) != 2 {
- device.log.Debug.Println(parts)
return &IPCError{Code: ipcErrorNoKeyValue}
}
key := parts[0]
@@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
switch key {
- /* Interface configuration */
+ /* interface configuration */
case "private_key":
if value == "" {
@@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
var sk NoisePrivateKey
err := sk.FromHex(value)
if err != nil {
- logger.Println("Failed to set private_key:", err)
+ logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.SetPrivateKey(sk)
@@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
var port int
_, err := fmt.Sscanf(value, "%d", &port)
if err != nil || port > (1<<16) || port < 0 {
- logger.Println("Failed to set listen_port:", err)
+ logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.net.mutex.Lock()
device.net.addr.Port = port
device.net.conn, err = net.ListenUDP("udp", device.net.addr)
device.net.mutex.Unlock()
+ if err != nil {
+ logError.Println("Failed to create UDP listener:", err)
+ return &IPCError{Code: ipcErrorInvalidValue}
+ }
case "fwmark":
- logger.Println("FWMark not handled yet")
+ logError.Println("FWMark not handled yet")
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
- logger.Println("Failed to get peer by public_key:", err)
+ logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.RLock()
@@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer = device.NewPeer(pubKey)
}
if peer == nil {
- panic(errors.New("bug: failed to find peer"))
+ panic(errors.New("bug: failed to find / create peer"))
}
case "replace_peers":
if value == "true" {
device.RemoveAllPeers()
} else {
- logger.Println("Failed to set replace_peers, invalid value:", value)
+ logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
default:
- /* Peer configuration */
+
+ /* peer configuration */
if peer == nil {
- logger.Println("No peer referenced, before peer operation")
+ logError.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer}
}
@@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Lock()
device.RemovePeer(peer.handshake.remoteStatic)
peer.mutex.Unlock()
- logger.Println("Remove peer")
+ logDebug.Println("Removing", peer.String())
peer = nil
case "preshared_key":
@@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return peer.handshake.presharedKey.FromHex(value)
}()
if err != nil {
- logger.Println("Failed to set preshared_key:", err)
+ logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "endpoint":
addr, err := net.ResolveUDPAddr("udp", value)
if err != nil {
- logger.Println("Failed to set endpoint:", value)
+ logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
@@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil {
- logger.Println("Failed to set persistent_keepalive_interval:", err)
+ logError.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
- peer.mutex.Lock()
- peer.persistentKeepaliveInterval = uint64(secs)
- peer.mutex.Unlock()
+ atomic.StoreUint64(
+ &peer.persistentKeepaliveInterval,
+ uint64(secs),
+ )
case "replace_allowed_ips":
if value == "true" {
device.routingTable.RemovePeer(peer)
} else {
- logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
+ logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
- logger.Println("Failed to set allowed_ip:", err)
+ logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
ones, _ := network.Mask.Size()
- logger.Println(network, ones, network.IP)
+ logError.Println(network, ones, network.IP)
device.routingTable.Insert(network.IP, uint(ones), peer)
- /* Invalid key */
-
default:
- logger.Println("Invalid key:", key)
+ logError.Println("Invalid UAPI key:", key)
return &IPCError{Code: ipcErrorInvalidKey}
}
}
@@ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcHandle(device *Device, socket net.Conn) {
- func() {
- buffered := func(s io.ReadWriter) *bufio.ReadWriter {
- reader := bufio.NewReader(s)
- writer := bufio.NewWriter(s)
- return bufio.NewReadWriter(reader, writer)
- }(socket)
+ defer socket.Close()
- defer buffered.Flush()
+ buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+ reader := bufio.NewReader(s)
+ writer := bufio.NewWriter(s)
+ return bufio.NewReadWriter(reader, writer)
+ }(socket)
- op, err := buffered.ReadString('\n')
- if err != nil {
- return
- }
+ defer buffered.Flush()
- switch op {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
- case "set=1\n":
- device.log.Debug.Println("Config, set operation")
- err := ipcSetOperation(device, buffered)
- if err != nil {
- fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
- }
- break
+ switch op {
- case "get=1\n":
- device.log.Debug.Println("Config, get operation")
- err := ipcGetOperation(device, buffered)
- if err != nil {
- fmt.Fprintf(buffered, "errno=1\n\n") // fix
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
- }
- break
+ case "set=1\n":
+ device.log.Debug.Println("Config, set operation")
+ err := ipcSetOperation(device, buffered)
+ if err != nil {
+ fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ return
- default:
- device.log.Info.Println("Invalid UAPI operation:", op)
+ case "get=1\n":
+ device.log.Debug.Println("Config, get operation")
+ err := ipcGetOperation(device, buffered)
+ if err != nil {
+ fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
}
- }()
+ return
+
+ default:
+ device.log.Error.Println("Invalid UAPI operation:", op)
- socket.Close()
+ }
}
diff --git a/src/device.go b/src/device.go
index 4981f51..d32d648 100644
--- a/src/device.go
+++ b/src/device.go
@@ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
defer device.mutex.Unlock()
device.log = NewLogger(logLevel)
- // device.mtu = tun.MTU()
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.ratelimiter.Init()
@@ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
logError := device.log.Error
- for ; ; time.Sleep(time.Second) {
+ for ; ; time.Sleep(5 * time.Second) {
+
+ // load updated MTU
+
mtu, err := tun.MTU()
if err != nil {
logError.Println("Failed to load updated MTU of device:", err)
continue
}
+
+ // upper bound of mtu
+
+ if mtu+MessageTransportSize > MaxMessageSize {
+ mtu = MaxMessageSize - MessageTransportSize
+ }
atomic.StoreInt32(&device.mtu, int32(mtu))
}
}
diff --git a/src/index.go b/src/index.go
index 59e2079..44b4974 100644
--- a/src/index.go
+++ b/src/index.go
@@ -7,8 +7,6 @@ import (
/* Index=0 is reserved for unset indecies
*
- * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
- *
*/
type IndexTableEntry struct {
@@ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
table.mutex.RLock()
_, ok := table.table[index]
+ table.mutex.RUnlock()
if ok {
continue
}
- table.mutex.RUnlock()
- // replace index
+ // map index to handshake
table.mutex.Lock()
_, found := table.table[index]
diff --git a/src/main.go b/src/main.go
index 74e7ec9..4bece16 100644
--- a/src/main.go
+++ b/src/main.go
@@ -17,12 +17,14 @@ func main() {
}
switch os.Args[1] {
+
case "-f", "--foreground":
foreground = true
if len(os.Args) != 3 {
return
}
interfaceName = os.Args[2]
+
default:
foreground = false
if len(os.Args) != 2 {
@@ -48,8 +50,8 @@ func main() {
// open TUN device
tun, err := CreateTUN(interfaceName)
- log.Println(tun, err)
if err != nil {
+ log.Println("Failed to create tun device:", err)
return
}
@@ -69,11 +71,15 @@ func main() {
}
defer uapi.Close()
- for {
- conn, err := uapi.Accept()
- if err != nil {
- logError.Fatal("accept error:", err)
+ go func() {
+ for {
+ conn, err := uapi.Accept()
+ if err != nil {
+ logError.Fatal("UAPI accept error:", err)
+ }
+ go ipcHandle(device, conn)
}
- go ipcHandle(device, conn)
- }
+ }()
+
+ device.Wait()
}
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index bfa3797..5fe6fb2 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// remap index
- peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
+ indices := &peer.device.indices
+ indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
keyPair: keyPair,
handshake: nil,
@@ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
if kp.previous != nil {
kp.previous.send = nil
kp.previous.receive = nil
- peer.device.indices.Delete(kp.previous.localIndex)
+ indices.Delete(kp.previous.localIndex)
}
kp.previous = kp.current
kp.current = keyPair
diff --git a/src/receive.go b/src/receive.go
index 31f74e2..e063c99 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() {
// add to peer queue
peer := value.peer
- work := &QueueInboundElement{
+ elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
keyPair: keyPair,
dropped: AtomicFalse,
}
- work.mutex.Lock()
+ elem.mutex.Lock()
// add to decryption queues
- device.addToInboundQueue(device.queue.decryption, work)
- device.addToInboundQueue(peer.queue.inbound, work)
+ device.addToInboundQueue(device.queue.decryption, elem)
+ device.addToInboundQueue(peer.queue.inbound, elem)
buffer = nil
default:
diff --git a/src/send.go b/src/send.go
index 2db74ba..fdbc676 100644
--- a/src/send.go
+++ b/src/send.go
@@ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() {
* Obs. One instance per core
*/
func (device *Device) RoutineEncryption() {
+
+ var elem *QueueOutboundElement
var nonce [chacha20poly1305.NonceSize]byte
- for work := range device.queue.encryption {
+
+ logDebug := device.log.Debug
+ logDebug.Println("Routine, encryption worker, started")
+
+ for {
+
+ // fetch next element
+
+ select {
+ case elem = <-device.queue.encryption:
+ case <-device.signal.stop:
+ logDebug.Println("Routine, encryption worker, stopped")
+ return
+ }
// check if dropped
- if work.IsDropped() {
+ if elem.IsDropped() {
continue
}
// populate header fields
- header := work.buffer[:MessageTransportHeaderSize]
+ header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to MTU size
mtu := int(atomic.LoadInt32(&device.mtu))
- for i := len(work.packet); i < mtu; i++ {
- work.packet = append(work.packet, 0)
+ for i := len(elem.packet); i < mtu; i++ {
+ elem.packet = append(elem.packet, 0)
}
// encrypt content
- binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
- work.packet = work.keyPair.send.Seal(
- work.packet[:0],
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keyPair.send.Seal(
+ elem.packet[:0],
nonce[:],
- work.packet,
+ elem.packet,
nil,
)
- length := MessageTransportHeaderSize + len(work.packet)
- work.packet = work.buffer[:length]
- work.mutex.Unlock()
+ length := MessageTransportHeaderSize + len(elem.packet)
+ elem.packet = elem.buffer[:length]
+ elem.mutex.Unlock()
// refresh key if necessary
- work.peer.KeepKeyFreshSending()
+ elem.peer.KeepKeyFreshSending()
}
}
@@ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() {
logDebug.Println("Routine, sequential sender, stopped for", peer.String())
return
- case work := <-peer.queue.outbound:
- work.mutex.Lock()
+ case elem := <-peer.queue.outbound:
+ elem.mutex.Lock()
func() {
-
- // return buffer to pool after processing
-
- defer device.PutMessageBuffer(work.buffer)
- if work.IsDropped() {
+ if elem.IsDropped() {
return
}
- // send to endpoint
+ // get endpoint and connection
peer.mutex.RLock()
- defer peer.mutex.RUnlock()
-
- if peer.endpoint == nil {
+ endpoint := peer.endpoint
+ peer.mutex.RUnlock()
+ if endpoint == nil {
logDebug.Println("No endpoint for", peer.String())
return
}
device.net.mutex.RLock()
- defer device.net.mutex.RUnlock()
-
- if device.net.conn == nil {
+ conn := device.net.conn
+ device.net.mutex.RUnlock()
+ if conn == nil {
logDebug.Println("No source for device")
return
}
- // send message and return buffer to pool
+ // send message and refresh keys
- _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
+ _, err := conn.WriteToUDP(elem.packet, endpoint)
if err != nil {
return
}
-
- atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
-
- // reset keep-alive
-
+ atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet)))
peer.TimerResetKeepalive()
}()
+
+ device.PutMessageBuffer(elem.buffer)
}
}
}
diff --git a/src/timers.go b/src/timers.go
index 9140e41..fd2bdc3 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
func (peer *Peer) RoutineTimerHandler() {
device := peer.device
+ indices := &device.indices
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
@@ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() {
logDebug.Println("Clearing all key material for", peer.String())
- // zero out key pairs
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
- func() {
- kp := &peer.keyPairs
- kp.mutex.Lock()
- // best we can do is wait for GC :( ?
- kp.current = nil
- kp.previous = nil
- kp.next = nil
- kp.mutex.Unlock()
- }()
+ hs := &peer.handshake
+ hs.mutex.Lock()
+
+ // unmap local indecies
+
+ indices.mutex.Lock()
+ if kp.previous != nil {
+ delete(indices.table, kp.previous.localIndex)
+ }
+ if kp.current != nil {
+ delete(indices.table, kp.current.localIndex)
+ }
+ if kp.next != nil {
+ delete(indices.table, kp.next.localIndex)
+ }
+ delete(indices.table, hs.localIndex)
+ indices.mutex.Unlock()
+
+ // zero out key pairs (TODO: better than wait for GC)
+
+ kp.current = nil
+ kp.previous = nil
+ kp.next = nil
+ kp.mutex.Unlock()
// zero out handshake
- func() {
- hs := &peer.handshake
- hs.mutex.Lock()
- hs.localEphemeral = NoisePrivateKey{}
- hs.remoteEphemeral = NoisePublicKey{}
- hs.chainKey = [blake2s.Size]byte{}
- hs.hash = [blake2s.Size]byte{}
- hs.mutex.Unlock()
- }()
+ hs.localIndex = 0
+ hs.localEphemeral = NoisePrivateKey{}
+ hs.remoteEphemeral = NoisePublicKey{}
+ hs.chainKey = [blake2s.Size]byte{}
+ hs.hash = [blake2s.Size]byte{}
+ hs.mutex.Unlock()
}
}
}