aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-06-04 21:48:15 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-06-04 21:48:15 +0200
commit1868d15914d6cd7cd57b90b7644b008ec16361b9 (patch)
treedbc788f49f433a5837db3c022facb19be38e4ea1 /src
parentTrie random test (diff)
downloadwireguard-go-1868d15914d6cd7cd57b90b7644b008ec16361b9.tar.xz
wireguard-go-1868d15914d6cd7cd57b90b7644b008ec16361b9.zip
Beginning work on TUN interface
And outbound routing I am not entirely convinced the use of net.IP is a good idea, since the internal representation of net.IP is a byte slice and all constructor functions in "net" return 16 byte slices (padded for IPv4), while the use in this project uses 4 byte slices. Which may be confusing.
Diffstat (limited to 'src')
-rw-r--r--src/config.go53
-rw-r--r--src/ip.go17
-rw-r--r--src/main.go26
-rw-r--r--src/peer.go10
-rw-r--r--src/routing.go55
-rw-r--r--src/trie.go40
-rw-r--r--src/trie_test.go63
-rw-r--r--src/tun.go8
-rw-r--r--src/tun_linux.go80
9 files changed, 290 insertions, 62 deletions
diff --git a/src/config.go b/src/config.go
index 62af67a..a61b940 100644
--- a/src/config.go
+++ b/src/config.go
@@ -7,6 +7,8 @@ import (
"io"
"log"
"net"
+ "strconv"
+ "time"
)
/* todo : use real error code
@@ -16,6 +18,7 @@ const (
ipcErrorNoPeer = 0
ipcErrorNoKeyValue = 1
ipcErrorInvalidKey = 2
+ ipcErrorInvalidValue = 2
ipcErrorInvalidPrivateKey = 3
ipcErrorInvalidPublicKey = 4
ipcErrorInvalidPort = 5
@@ -34,18 +37,16 @@ func (s *IPCError) ErrorCode() int {
return s.Code
}
-// Writes the configuration to the socket
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
}
-// Creates new config, from old and socket message
-func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
+func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
- dev.mutex.Lock()
- defer dev.mutex.Unlock()
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
for scanner.Scan() {
var key string
@@ -71,16 +72,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
case "private_key":
if value == "" {
- dev.privateKey = NoisePrivateKey{}
+ device.privateKey = NoisePrivateKey{}
} else {
- err := dev.privateKey.FromHex(value)
+ err := device.privateKey.FromHex(value)
if err != nil {
return &IPCError{Code: ipcErrorInvalidPrivateKey}
}
}
case "listen_port":
- _, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
+ _, err := fmt.Sscanf(value, "%ud", &device.listenPort)
if err != nil {
return &IPCError{Code: ipcErrorInvalidPort}
}
@@ -94,7 +95,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey}
}
- found, ok := dev.peers[pubKey]
+ found, ok := device.peers[pubKey]
if ok {
peer = found
} else {
@@ -102,14 +103,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
publicKey: pubKey,
}
peer = newPeer
- dev.peers[pubKey] = newPeer
+ device.peers[pubKey] = newPeer
}
case "replace_peers":
if key == "true" {
- dev.RemoveAllPeers()
+ device.RemoveAllPeers()
+ } else if key == "false" {
+ } else {
+ return &IPCError{Code: ipcErrorInvalidValue}
}
- // todo: else fail
default:
/* Peer configuration */
@@ -122,7 +125,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
case "remove":
peer.mutex.Lock()
- dev.RemovePeer(peer.publicKey)
+ device.RemovePeer(peer.publicKey)
peer = nil
case "preshared_key":
@@ -145,15 +148,29 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Unlock()
case "persistent_keepalive_interval":
- func() {
- peer.mutex.Lock()
- defer peer.mutex.Unlock()
- }()
+ secs, err := strconv.ParseInt(value, 10, 64)
+ if secs < 0 || err != nil {
+ return &IPCError{Code: ipcErrorInvalidValue}
+ }
+ peer.mutex.Lock()
+ peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
+ peer.mutex.Unlock()
case "replace_allowed_ips":
- // remove peer from trie
+ if key == "true" {
+ device.routingTable.RemovePeer(peer)
+ } else if key == "false" {
+ } else {
+ return &IPCError{Code: ipcErrorInvalidValue}
+ }
case "allowed_ip":
+ _, network, err := net.ParseCIDR(value)
+ if err != nil {
+ return &IPCError{Code: ipcErrorInvalidValue}
+ }
+ ones, _ := network.Mask.Size()
+ device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */
diff --git a/src/ip.go b/src/ip.go
new file mode 100644
index 0000000..3137891
--- /dev/null
+++ b/src/ip.go
@@ -0,0 +1,17 @@
+package main
+
+import (
+ "net"
+)
+
+const (
+ IPv4version = 4
+ IPv4offsetSrc = 12
+ IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+)
+
+const (
+ IPv6version = 6
+ IPv6offsetSrc = 8
+ IPv6offsetDst = IPv6offsetSrc + net.IPv6len
+)
diff --git a/src/main.go b/src/main.go
index 0f5016d..af336f0 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,11 +1,33 @@
package main
+import "fmt"
+
+func main() {
+ fd, err := CreateTUN("test0")
+ fmt.Println(fd, err)
+
+ queue := make(chan []byte, 1000)
+
+ var device Device
+
+ go OutgoingRoutingWorker(&device, queue)
+
+ for {
+ tmp := make([]byte, 1<<16)
+ n, err := fd.Read(tmp)
+ if err != nil {
+ break
+ }
+ queue <- tmp[:n]
+ }
+}
+
+/*
import (
"fmt"
"log"
"net"
)
-
func main() {
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
if err != nil {
@@ -24,5 +46,5 @@ func main() {
fmt.Println(err)
}(fd)
}
-
}
+*/
diff --git a/src/peer.go b/src/peer.go
index 7b2b2a6..db5e99f 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -3,6 +3,7 @@ package main
import (
"net"
"sync"
+ "time"
)
type KeyPair struct {
@@ -13,8 +14,9 @@ type KeyPair struct {
}
type Peer struct {
- mutex sync.RWMutex
- publicKey NoisePublicKey
- presharedKey NoiseSymmetricKey
- endpoint net.IP
+ mutex sync.RWMutex
+ publicKey NoisePublicKey
+ presharedKey NoiseSymmetricKey
+ endpoint net.IP
+ persistentKeepaliveInterval time.Duration
}
diff --git a/src/routing.go b/src/routing.go
index 99b180c..0aa111c 100644
--- a/src/routing.go
+++ b/src/routing.go
@@ -1,13 +1,12 @@
package main
import (
+ "errors"
+ "fmt"
+ "net"
"sync"
)
-/* Thread-safe high level functions for cryptkey routing.
- *
- */
-
type RoutingTable struct {
IPv4 *Trie
IPv6 *Trie
@@ -20,3 +19,51 @@ func (table *RoutingTable) RemovePeer(peer *Peer) {
table.IPv4 = table.IPv4.RemovePeer(peer)
table.IPv6 = table.IPv6.RemovePeer(peer)
}
+
+func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ switch len(ip) {
+ case net.IPv6len:
+ table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
+ case net.IPv4len:
+ table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
+ default:
+ panic(errors.New("Inserting unknown address type"))
+ }
+}
+
+func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv4.Lookup(address)
+}
+
+func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv6.Lookup(address)
+}
+
+func OutgoingRoutingWorker(device *Device, queue chan []byte) {
+ for {
+ packet := <-queue
+ switch packet[0] >> 4 {
+
+ case IPv4version:
+ dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer := device.routingTable.LookupIPv4(dst)
+ fmt.Println("IPv4", peer)
+
+ case IPv6version:
+ dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer := device.routingTable.LookupIPv6(dst)
+ fmt.Println("IPv6", peer)
+
+ default:
+ // todo: log
+ fmt.Println("Unknown IP version")
+ }
+ }
+}
diff --git a/src/trie.go b/src/trie.go
index 31a4d92..746c1b4 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -1,5 +1,9 @@
package main
+import (
+ "net"
+)
+
/* Binary trie
*
* Syncronization done seperatly
@@ -22,13 +26,13 @@ type Trie struct {
/* Finds length of matching prefix
* Maybe there is a faster way
*
- * Assumption: len(s1) == len(s2)
+ * Assumption: len(ip1) == len(ip2)
*/
-func commonBits(s1 []byte, s2 []byte) uint {
+func commonBits(ip1 net.IP, ip2 net.IP) uint {
var i uint
- size := uint(len(s1))
+ size := uint(len(ip1))
for i = 0; i < size; i += 1 {
- v := s1[i] ^ s2[i]
+ v := ip1[i] ^ ip2[i]
if v != 0 {
v >>= 1
if v == 0 {
@@ -93,17 +97,17 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node.child[0]
}
-func (node *Trie) choose(key []byte) byte {
- return (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+func (node *Trie) choose(ip net.IP) byte {
+ return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
-func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
+func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
// At leaf
if node == nil {
return &Trie{
- bits: key,
+ bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
@@ -113,21 +117,21 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
// Traverse deeper
- common := commonBits(node.bits, key)
+ common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.peer = peer
return node
}
- bit := node.choose(key)
- node.child[bit] = node.child[bit].Insert(key, cidr, peer)
+ bit := node.choose(ip)
+ node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
return node
}
// Split node
newNode := &Trie{
- bits: key,
+ bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
@@ -147,31 +151,31 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
// Create new parent for node & newNode
parent := &Trie{
- bits: key,
+ bits: ip,
peer: nil,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
- bit := parent.choose(key)
+ bit := parent.choose(ip)
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
}
-func (node *Trie) Lookup(key []byte) *Peer {
+func (node *Trie) Lookup(ip net.IP) *Peer {
var found *Peer
- size := uint(len(key))
- for node != nil && commonBits(node.bits, key) >= node.cidr {
+ size := uint(len(ip))
+ for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bit_at_byte == size {
break
}
- bit := node.choose(key)
+ bit := node.choose(ip)
node = node.child[bit]
}
return found
diff --git a/src/trie_test.go b/src/trie_test.go
index 35af0aa..9d53df3 100644
--- a/src/trie_test.go
+++ b/src/trie_test.go
@@ -1,6 +1,8 @@
package main
import (
+ "math/rand"
+ "net"
"testing"
)
@@ -55,6 +57,49 @@ func TestCommonBits(t *testing.T) {
}
}
+func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+ var trie *Trie
+ var peers []*Peer
+
+ rand.Seed(1)
+
+ const AddressLength = 4
+
+ for n := 0; n < peerNumber; n += 1 {
+ peers = append(peers, &Peer{})
+ }
+
+ for n := 0; n < addressNumber; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ cidr := uint(rand.Uint32() % (AddressLength * 8))
+ index := rand.Int() % peerNumber
+ trie = trie.Insert(addr[:], cidr, peers[index])
+ }
+
+ for n := 0; n < b.N; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ trie.Lookup(addr[:])
+ }
+}
+
+func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv6len, b)
+}
+
+func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv6len, b)
+}
+
/* Test ported from kernel implementation:
* selftest/routingtable.h
*/
@@ -91,10 +136,10 @@ func TestTrieIPv4(t *testing.T) {
insert(b, 192, 168, 4, 4, 32)
insert(c, 192, 168, 0, 0, 16)
insert(d, 192, 95, 5, 64, 27)
- insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */
+ insert(c, 192, 95, 5, 65, 27)
insert(e, 0, 0, 0, 0, 0)
insert(g, 64, 15, 112, 0, 20)
- insert(h, 64, 15, 123, 211, 25) /* maskself is required */
+ insert(h, 64, 15, 123, 211, 25)
insert(a, 10, 0, 0, 0, 25)
insert(b, 10, 0, 0, 128, 25)
insert(a, 10, 1, 0, 0, 30)
@@ -186,20 +231,6 @@ func TestTrieIPv6(t *testing.T) {
}
}
- /*
- assertNEQ := func(peer *Peer, a, b, c, d uint32) {
- var addr []byte
- addr = append(addr, expand(a)...)
- addr = append(addr, expand(b)...)
- addr = append(addr, expand(c)...)
- addr = append(addr, expand(d)...)
- p := trie.Lookup(addr)
- if p == peer {
- t.Error("Assert NEQ failed")
- }
- }
- */
-
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
diff --git a/src/tun.go b/src/tun.go
new file mode 100644
index 0000000..1a8bb82
--- /dev/null
+++ b/src/tun.go
@@ -0,0 +1,8 @@
+package main
+
+type TUN interface {
+ Read([]byte) (int, error)
+ Write([]byte) (int, error)
+ Name() string
+ MTU() uint
+}
diff --git a/src/tun_linux.go b/src/tun_linux.go
new file mode 100644
index 0000000..d545dfa
--- /dev/null
+++ b/src/tun_linux.go
@@ -0,0 +1,80 @@
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "os"
+ "strings"
+ "syscall"
+ "unsafe"
+)
+
+/* Platform dependent functions for interacting with
+ * TUN devices on linux systems
+ *
+ */
+
+const CloneDevicePath = "/dev/net/tun"
+
+const (
+ IFF_NO_PI = 0x1000
+ IFF_TUN = 0x1
+ IFNAMSIZ = 0x10
+ TUNSETIFF = 0x400454CA
+)
+
+type NativeTun struct {
+ fd *os.File
+ name string
+ mtu uint
+}
+
+func (tun *NativeTun) Name() string {
+ return tun.name
+}
+
+func (tun *NativeTun) MTU() uint {
+ return tun.mtu
+}
+
+func (tun *NativeTun) Write(d []byte) (int, error) {
+ return tun.fd.Write(d)
+}
+
+func (tun *NativeTun) Read(d []byte) (int, error) {
+ return tun.fd.Read(d)
+}
+
+func CreateTUN(name string) (TUN, error) {
+ // Open clone device
+ fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ // Prepare ifreq struct
+ var ifr [18]byte
+ var flags uint16 = IFF_TUN | IFF_NO_PI
+ nameBytes := []byte(name)
+ if len(nameBytes) >= IFNAMSIZ {
+ return nil, errors.New("Name size too long")
+ }
+ copy(ifr[:], nameBytes)
+ binary.LittleEndian.PutUint16(ifr[16:], flags)
+
+ // Create new device
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL,
+ uintptr(fd.Fd()), uintptr(TUNSETIFF),
+ uintptr(unsafe.Pointer(&ifr[0])))
+ if errno != 0 {
+ return nil, errors.New("Failed to create tun, ioctl call failed")
+ }
+
+ // Read name of interface
+ newName := string(ifr[:])
+ newName = newName[:strings.Index(newName, "\000")]
+ return &NativeTun{
+ fd: fd,
+ name: newName,
+ }, nil
+}