aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile3
-rw-r--r--allowedips.go (renamed from trie.go)110
-rw-r--r--allowedips_rand_test.go (renamed from trie_rand_test.go)16
-rw-r--r--allowedips_test.go (renamed from trie_test.go)26
-rw-r--r--device.go4
-rw-r--r--keypair.go2
-rw-r--r--logger.go2
-rw-r--r--noise-helpers.go3
-rw-r--r--noise-types.go2
-rw-r--r--peer.go24
-rw-r--r--receive.go26
-rw-r--r--routing.go70
-rw-r--r--tun_darwin.go4
-rw-r--r--tun_linux.go1
-rw-r--r--tun_windows.go4
-rw-r--r--uapi.go6
16 files changed, 139 insertions, 164 deletions
diff --git a/Makefile b/Makefile
index 5b23ecc..77eaac9 100644
--- a/Makefile
+++ b/Makefile
@@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go)
clean:
rm -f wireguard-go
-cloc:
- cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go))
-
.PHONY: clean cloc
diff --git a/trie.go b/allowedips.go
index 03f0722..df53abf 100644
--- a/trie.go
+++ b/allowedips.go
@@ -8,21 +8,12 @@ package main
import (
"errors"
"net"
+ "sync"
)
-/* Binary trie
- *
- * The net.IPs used here are not formatted the
- * same way as those created by the "net" functions.
- * Here the IPs are slices of either 4 or 16 byte (not always 16)
- *
- * Synchronization done separately
- * See: routing.go
- */
-
-type Trie struct {
+type trieEntry struct {
cidr uint
- child [2]*Trie
+ child [2]*trieEntry
bits []byte
peer *Peer
@@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint {
return i * 8
}
-func (node *Trie) RemovePeer(p *Peer) *Trie {
+func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
- node.child[0] = node.child[0].RemovePeer(p)
- node.child[1] = node.child[1].RemovePeer(p)
+ node.child[0] = node.child[0].removeByPeer(p)
+ node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
@@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node.child[0]
}
-func (node *Trie) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
-func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
+func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// at leaf
if node == nil {
- return &Trie{
+ return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
@@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return node
}
bit := node.choose(ip)
- node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
+ node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
- newNode := &Trie{
+ newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
@@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
// create new parent for node & newNode
- parent := &Trie{
+ parent := &trieEntry{
bits: ip,
peer: nil,
cidr: cidr,
@@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return parent
}
-func (node *Trie) Lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer {
return found
}
-func (node *Trie) Count() uint {
- if node == nil {
- return 0
- }
- l := node.child[0].Count()
- r := node.child[1].Count()
- return l + r
-}
-
-func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
+func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
@@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
- panic(errors.New("bug: unexpected address length"))
+ panic(errors.New("unexpected address length"))
}
results = append(results, mask)
}
- results = node.child[0].AllowedIPs(p, results)
- results = node.child[1].AllowedIPs(p, results)
+ results = node.child[0].entriesForPeer(p, results)
+ results = node.child[1].entriesForPeer(p, results)
return results
}
+
+type AllowedIPs struct {
+ IPv4 *trieEntry
+ IPv6 *trieEntry
+ mutex sync.RWMutex
+}
+
+func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+
+ allowed := make([]net.IPNet, 0, 10)
+ allowed = table.IPv4.entriesForPeer(peer, allowed)
+ allowed = table.IPv6.entriesForPeer(peer, allowed)
+ return allowed
+}
+
+func (table *AllowedIPs) Reset() {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = nil
+ table.IPv6 = nil
+}
+
+func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = table.IPv4.removeByPeer(peer)
+ table.IPv6 = table.IPv6.removeByPeer(peer)
+}
+
+func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ switch len(ip) {
+ case net.IPv6len:
+ table.IPv6 = table.IPv6.insert(ip, cidr, peer)
+ case net.IPv4len:
+ table.IPv4 = table.IPv4.insert(ip, cidr, peer)
+ default:
+ panic(errors.New("inserting unknown address type"))
+ }
+}
+
+func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv4.lookup(address)
+}
+
+func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv6.lookup(address)
+}
diff --git a/trie_rand_test.go b/allowedips_rand_test.go
index 157c270..6ec039d 100644
--- a/trie_rand_test.go
+++ b/allowedips_rand_test.go
@@ -65,7 +65,7 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
}
func TestTrieRandomIPv4(t *testing.T) {
- var trie *Trie
+ var trie *trieEntry
var slow SlowRouter
var peers []*Peer
@@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
- trie = trie.Insert(addr[:], cidr, peers[index])
+ trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
@@ -90,15 +90,15 @@ func TestTrieRandomIPv4(t *testing.T) {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
- peer2 := trie.Lookup(addr[:])
+ peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ t.Error("trieEntry did not match naive implementation, for:", addr)
}
}
}
func TestTrieRandomIPv6(t *testing.T) {
- var trie *Trie
+ var trie *trieEntry
var slow SlowRouter
var peers []*Peer
@@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
- trie = trie.Insert(addr[:], cidr, peers[index])
+ trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
@@ -123,9 +123,9 @@ func TestTrieRandomIPv6(t *testing.T) {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
- peer2 := trie.Lookup(addr[:])
+ peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ t.Error("trieEntry did not match naive implementation, for:", addr)
}
}
}
diff --git a/trie_test.go b/allowedips_test.go
index 3c3b5ba..7b73af3 100644
--- a/trie_test.go
+++ b/allowedips_test.go
@@ -31,7 +31,7 @@ type testPairTrieLookup struct {
peer *Peer
}
-func printTrie(t *testing.T, p *Trie) {
+func printTrie(t *testing.T, p *trieEntry) {
if p == nil {
return
}
@@ -63,7 +63,7 @@ func TestCommonBits(t *testing.T) {
}
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
- var trie *Trie
+ var trie *trieEntry
var peers []*Peer
rand.Seed(1)
@@ -79,13 +79,13 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
- trie = trie.Insert(addr[:], cidr, peers[index])
+ trie = trie.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
- trie.Lookup(addr[:])
+ trie.lookup(addr[:])
}
}
@@ -117,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *Trie
+ var trie *trieEntry
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
- trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
+ trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.Lookup([]byte{a, b, c, d})
+ p := trie.lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.Lookup([]byte{a, b, c, d})
+ p := trie.lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -173,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.RemovePeer(a)
+ trie = trie.removeByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -186,7 +186,7 @@ func TestTrieIPv4(t *testing.T) {
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.RemovePeer(a)
+ trie = trie.removeByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
@@ -204,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *Trie
+ var trie *trieEntry
expand := func(a uint32) []byte {
var out [4]byte
@@ -221,7 +221,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.Insert(addr, cidr, peer)
+ trie = trie.insert(addr, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -230,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.Lookup(addr)
+ p := trie.lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
diff --git a/device.go b/device.go
index 99e451e..34af419 100644
--- a/device.go
+++ b/device.go
@@ -46,7 +46,7 @@ type Device struct {
routing struct {
mutex sync.RWMutex
- table RoutingTable
+ table AllowedIPs
}
peers struct {
@@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
- device.routing.table.RemovePeer(peer)
+ device.routing.table.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
diff --git a/keypair.go b/keypair.go
index 6f6f7c0..ea72a11 100644
--- a/keypair.go
+++ b/keypair.go
@@ -33,7 +33,7 @@ type Keypairs struct {
mutex sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair // not yet "confirmed by transport"
+ next *Keypair
}
func (kp *Keypairs) Current() *Keypair {
diff --git a/logger.go b/logger.go
index 784235c..b8012aa 100644
--- a/logger.go
+++ b/logger.go
@@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger {
logger.Debug = log.New(logDebug,
"DEBUG: "+prepend,
- log.Ldate|log.Ltime|log.Lshortfile,
+ log.Ldate|log.Ltime,
)
logger.Info = log.New(logInfo,
diff --git a/noise-helpers.go b/noise-helpers.go
index 6e23d83..63e45b3 100644
--- a/noise-helpers.go
+++ b/noise-helpers.go
@@ -71,14 +71,13 @@ func isZero(val []byte) bool {
return acc == 1
}
+/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
func setZero(arr []byte) {
for i := range arr {
arr[i] = 0
}
}
-/* curve25519 wrappers */
-
func newPrivateKey() (sk NoisePrivateKey, err error) {
// clamping: https://cr.yp.to/ecdh.html
_, err = rand.Read(sk[:])
diff --git a/noise-types.go b/noise-types.go
index 58aa0c2..2635e01 100644
--- a/noise-types.go
+++ b/noise-types.go
@@ -30,7 +30,7 @@ func loadExactHex(dst []byte, src string) error {
return err
}
if len(slice) != len(dst) {
- return errors.New("Hex string does not fit the slice")
+ return errors.New("hex string does not fit the slice")
}
copy(dst, slice)
return nil
diff --git a/peer.go b/peer.go
index f49f806..d574c71 100644
--- a/peer.go
+++ b/peer.go
@@ -61,7 +61,7 @@ type Peer struct {
mutex sync.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
- stop chan struct{} // size 0, stop all go-routines in peer
+ stop chan struct{} // size 0, stop all go routines in peer
}
mac CookieGenerator
@@ -70,7 +70,7 @@ type Peer struct {
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() {
- return nil, errors.New("Device closed")
+ return nil, errors.New("device closed")
}
// lock resources
@@ -87,7 +87,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
- return nil, errors.New("Too many peers")
+ return nil, errors.New("too many peers")
}
// create peer
@@ -104,7 +104,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
_, ok := device.peers.keyMap[pk]
if ok {
- return nil, errors.New("Adding existing peer")
+ return nil, errors.New("adding existing peer")
}
device.peers.keyMap[pk] = peer
@@ -134,26 +134,26 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
defer peer.device.net.mutex.RUnlock()
if peer.device.net.bind == nil {
- return errors.New("No bind")
+ return errors.New("no bind")
}
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
- return errors.New("No known endpoint for peer")
+ return errors.New("no known endpoint for peer")
}
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
-/* Returns a short string identifier for logging
- */
func (peer *Peer) String() string {
- return fmt.Sprintf(
- "peer(%s)",
- base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
- )
+ base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ abbreviatedKey := "invalid"
+ if len(base64Key) == 44 {
+ abbreviatedKey = base64Key[0:4] + "..." + base64Key[40:44]
+ }
+ return fmt.Sprintf("peer(%s)", abbreviatedKey)
}
func (peer *Peer) Start() {
diff --git a/receive.go b/receive.go
index 60a2510..32ff512 100644
--- a/receive.go
+++ b/receive.go
@@ -600,20 +600,24 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check if using new key-pair
kp := &peer.keypairs
- kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
if kp.next == elem.keypair {
- old := kp.previous
- kp.previous = kp.current
- device.DeleteKeypair(old)
- kp.current = kp.next
- kp.next = nil
- peer.timersHandshakeComplete()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
+ kp.mutex.Lock()
+ if kp.next != elem.keypair {
+ kp.mutex.Unlock()
+ } else {
+ old := kp.previous
+ kp.previous = kp.current
+ device.DeleteKeypair(old)
+ kp.current = kp.next
+ kp.next = nil
+ kp.mutex.Unlock()
+ peer.timersHandshakeComplete()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
+ }
}
}
- kp.mutex.Unlock()
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
diff --git a/routing.go b/routing.go
deleted file mode 100644
index 77c9b1e..0000000
--- a/routing.go
+++ /dev/null
@@ -1,70 +0,0 @@
-/* SPDX-License-Identifier: GPL-2.0
- *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-import (
- "errors"
- "net"
- "sync"
-)
-
-type RoutingTable struct {
- IPv4 *Trie
- IPv6 *Trie
- mutex sync.RWMutex
-}
-
-func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
-
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.AllowedIPs(peer, allowed)
- allowed = table.IPv6.AllowedIPs(peer, allowed)
- return allowed
-}
-
-func (table *RoutingTable) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
-}
-
-func (table *RoutingTable) RemovePeer(peer *Peer) {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = table.IPv4.RemovePeer(peer)
- table.IPv6 = table.IPv6.RemovePeer(peer)
-}
-
-func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- switch len(ip) {
- case net.IPv6len:
- table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
- case net.IPv4len:
- table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
- default:
- panic(errors.New("Inserting unknown address type"))
- }
-}
-
-func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv4.Lookup(address)
-}
-
-func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.Lookup(address)
-}
diff --git a/tun_darwin.go b/tun_darwin.go
index 1d66c66..fa8efe0 100644
--- a/tun_darwin.go
+++ b/tun_darwin.go
@@ -224,7 +224,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
}
func (tun *NativeTun) Close() error {
- return tun.fd.Close()
+ err := tun.fd.Close()
+ close(tun.events)
+ return err
}
func (tun *NativeTun) setMTU(n int) error {
diff --git a/tun_linux.go b/tun_linux.go
index 18994cc..9f60d2b 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -392,6 +392,7 @@ func (tun *NativeTun) Close() error {
return err
}
tun.closingWriter.Write([]byte{0})
+ close(tun.events)
return nil
}
diff --git a/tun_windows.go b/tun_windows.go
index c0c9ff8..6eea5a3 100644
--- a/tun_windows.go
+++ b/tun_windows.go
@@ -125,7 +125,9 @@ func (f *NativeTUN) Events() chan TUNEvent {
}
func (f *NativeTUN) Close() error {
- return windows.Close(f.fd)
+ close(f.events)
+ err := windows.Close(f.fd)
+ return err
}
func (f *NativeTUN) Write(b []byte) (int, error) {
diff --git a/uapi.go b/uapi.go
index 4b2038b..90c400a 100644
--- a/uapi.go
+++ b/uapi.go
@@ -91,7 +91,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
- for _, ip := range device.routing.table.AllowedIPs(peer) {
+ for _, ip := range device.routing.table.EntriesForPeer(peer) {
send("allowed_ip=" + ip.String())
}
@@ -337,7 +337,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "replace_allowed_ips":
- logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer)
+ logDebug.Println("UAPI: Removing all allowed EntriesForPeer for peer:", peer)
if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
@@ -349,7 +349,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
device.routing.mutex.Lock()
- device.routing.table.RemovePeer(peer)
+ device.routing.table.RemoveByPeer(peer)
device.routing.mutex.Unlock()
case "allowed_ip":