aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.go132
-rw-r--r--src/device.go8
-rw-r--r--src/main.go25
-rw-r--r--src/routing.go6
-rw-r--r--src/send.go8
-rw-r--r--src/trie.go10
6 files changed, 109 insertions, 80 deletions
diff --git a/src/config.go b/src/config.go
index 3b91d00..2f8dc76 100644
--- a/src/config.go
+++ b/src/config.go
@@ -5,24 +5,22 @@ import (
"errors"
"fmt"
"io"
- "log"
"net"
"strconv"
+ "strings"
"time"
)
-/* TODO : use real error code
- * Many of which will be the same
+// #include <errno.h>
+import "C"
+
+/* TODO: More fine grained?
*/
const (
- ipcErrorNoPeer = 0
- ipcErrorNoKeyValue = 1
- ipcErrorInvalidKey = 2
- ipcErrorInvalidValue = 2
- ipcErrorInvalidPrivateKey = 3
- ipcErrorInvalidPublicKey = 4
- ipcErrorInvalidPort = 5
- ipcErrorInvalidIPAddress = 6
+ ipcErrorNoPeer = C.EPROTO
+ ipcErrorNoKeyValue = C.EPROTO
+ ipcErrorInvalidKey = C.EPROTO
+ ipcErrorInvalidValue = C.EPROTO
)
type IPCError struct {
@@ -78,7 +76,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
// send lines
for _, line := range lines {
- device.log.Debug.Println("config:", line)
+ device.log.Debug.Println("Response:", line)
_, err := socket.WriteString(line + "\n")
if err != nil {
return err
@@ -89,29 +87,26 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
-
+ logger := device.log.Debug
scanner := bufio.NewScanner(socket)
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
+ var peer *Peer
for scanner.Scan() {
- var key string
- var value string
- var peer *Peer
// Parse line
line := scanner.Text()
- if line == "\n" {
- break
+ if line == "" {
+ return nil
}
- fmt.Println(line)
- n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
- if n != 2 || err != nil {
- fmt.Println(err, n)
+ parts := strings.Split(line, "=")
+ if len(parts) != 2 {
+ device.log.Debug.Println(parts)
return &IPCError{Code: ipcErrorNoKeyValue}
}
+ key := parts[0]
+ value := parts[1]
+ logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key {
@@ -119,41 +114,60 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "private_key":
if value == "" {
+ device.mutex.Lock()
device.privateKey = NoisePrivateKey{}
+ device.mutex.Unlock()
} else {
+ device.mutex.Lock()
err := device.privateKey.FromHex(value)
+ device.mutex.Unlock()
if err != nil {
- return &IPCError{Code: ipcErrorInvalidPrivateKey}
+ logger.Println("Failed to set private_key:", err)
+ return &IPCError{Code: ipcErrorInvalidValue}
}
}
case "listen_port":
- _, err := fmt.Sscanf(value, "%ud", &device.address.Port)
- if err != nil {
- return &IPCError{Code: ipcErrorInvalidPort}
+ var port int
+ _, err := fmt.Sscanf(value, "%d", &port)
+ if err != nil || port > (1<<16) || port < 0 {
+ logger.Println("Failed to set listen_port:", err)
+ return &IPCError{Code: ipcErrorInvalidValue}
}
+ device.mutex.Lock()
+ if device.address == nil {
+ device.address = &net.UDPAddr{}
+ }
+ device.address.Port = port
+ device.mutex.Unlock()
case "fwmark":
- panic(nil) // not handled yet
+ logger.Println("FWMark not handled yet")
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
- return &IPCError{Code: ipcErrorInvalidPublicKey}
+ logger.Println("Failed to get peer by public_key:", err)
+ return &IPCError{Code: ipcErrorInvalidValue}
}
+ device.mutex.RLock()
found, ok := device.peers[pubKey]
+ device.mutex.RUnlock()
if ok {
peer = found
} else {
peer = device.NewPeer(pubKey)
}
+ if peer == nil {
+ panic(errors.New("bug: failed to find peer"))
+ }
case "replace_peers":
- if key == "true" {
+ if value == "true" {
device.RemoveAllPeers()
- } else if key == "false" {
} else {
+ logger.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
@@ -161,6 +175,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
/* Peer configuration */
if peer == nil {
+ logger.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer}
}
@@ -168,7 +183,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "remove":
peer.mutex.Lock()
- // device.RemovePeer(peer.publicKey)
+ device.RemovePeer(peer.handshake.remoteStatic)
+ peer.mutex.Unlock()
+ logger.Println("Remove peer")
peer = nil
case "preshared_key":
@@ -178,13 +195,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return peer.handshake.presharedKey.FromHex(value)
}()
if err != nil {
- return &IPCError{Code: ipcErrorInvalidPublicKey}
+ logger.Println("Failed to set preshared_key:", err)
+ return &IPCError{Code: ipcErrorInvalidValue}
}
case "endpoint":
ip := net.ParseIP(value)
if ip == nil {
- return &IPCError{Code: ipcErrorInvalidIPAddress}
+ logger.Println("Failed to set endpoint:", value)
+ return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
// peer.endpoint = ip FIX
@@ -193,6 +212,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil {
+ logger.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
@@ -200,24 +220,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Unlock()
case "replace_allowed_ips":
- if key == "true" {
+ if value == "true" {
device.routingTable.RemovePeer(peer)
- } else if key == "false" {
} else {
+ logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
+ logger.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
ones, _ := network.Mask.Size()
+ logger.Println(network, ones, network.IP)
device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */
default:
+ logger.Println("Invalid key:", key)
return &IPCError{Code: ipcErrorInvalidKey}
}
}
@@ -226,49 +249,48 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
-func ipcListen(device *Device, socket io.ReadWriter) error {
+func ipcHandle(device *Device, socket net.Conn) {
- buffered := func(s io.ReadWriter) *bufio.ReadWriter {
- reader := bufio.NewReader(s)
- writer := bufio.NewWriter(s)
- return bufio.NewReadWriter(reader, writer)
- }(socket)
+ func() {
+ buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+ reader := bufio.NewReader(s)
+ writer := bufio.NewWriter(s)
+ return bufio.NewReadWriter(reader, writer)
+ }(socket)
- defer buffered.Flush()
+ defer buffered.Flush()
- for {
op, err := buffered.ReadString('\n')
if err != nil {
- return err
+ return
}
- log.Println(op)
switch op {
case "set=1\n":
+ device.log.Debug.Println("Config, set operation")
err := ipcSetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
- return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
- buffered.Flush()
+ break
case "get=1\n":
+ device.log.Debug.Println("Config, get operation")
err := ipcGetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=1\n\n") // fix
- return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
- buffered.Flush()
+ break
- case "\n":
default:
- return errors.New("handle this please")
+ device.log.Info.Println("Invalid UAPI operation:", op)
}
- }
+ }()
+ socket.Close()
}
diff --git a/src/device.go b/src/device.go
index a7a5c7b..52ac6a4 100644
--- a/src/device.go
+++ b/src/device.go
@@ -81,10 +81,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
-}
-
-func (device *Device) RemoveAllAllowedIps(peer *Peer) {
-
+ peer.Close()
}
func (device *Device) RemoveAllPeers() {
@@ -93,8 +90,7 @@ func (device *Device) RemoveAllPeers() {
for key, peer := range device.peers {
peer.mutex.Lock()
- device.routingTable.RemovePeer(peer)
delete(device.peers, key)
- peer.mutex.Unlock()
+ peer.Close()
}
}
diff --git a/src/main.go b/src/main.go
index 7c58972..9c76ff4 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,21 +1,28 @@
package main
import (
+ "fmt"
"log"
"net"
+ "os"
)
-/*
- *
- * TODO: Fix logging
+/* TODO: Fix logging
+ * TODO: Fix daemon
*/
func main() {
+
+ if len(os.Args) != 2 {
+ return
+ }
+ deviceName := os.Args[1]
+
// Open TUN device
// TODO: Fix capabilities
- tun, err := CreateTUN("test0")
+ tun, err := CreateTUN(deviceName)
log.Println(tun, err)
if err != nil {
return
@@ -25,19 +32,17 @@ func main() {
// Start configuration lister
- l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
+ socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
+ l, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatal("listen error:", err)
}
for {
- fd, err := l.Accept()
+ conn, err := l.Accept()
if err != nil {
log.Fatal("accept error:", err)
}
- go func(conn net.Conn) {
- err := ipcListen(device, conn)
- log.Println(err)
- }(fd)
+ go ipcHandle(device, conn)
}
}
diff --git a/src/routing.go b/src/routing.go
index 6a5e1f3..2a2e237 100644
--- a/src/routing.go
+++ b/src/routing.go
@@ -16,9 +16,9 @@ func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
- allowed := make([]net.IPNet, 10)
- table.IPv4.AllowedIPs(peer, allowed)
- table.IPv6.AllowedIPs(peer, allowed)
+ allowed := make([]net.IPNet, 0, 10)
+ allowed = table.IPv4.AllowedIPs(peer, allowed)
+ allowed = table.IPv6.AllowedIPs(peer, allowed)
return allowed
}
diff --git a/src/send.go b/src/send.go
index 4ff75db..ab75750 100644
--- a/src/send.go
+++ b/src/send.go
@@ -61,9 +61,11 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+ device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
+ device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
@@ -76,8 +78,6 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
continue
}
- device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
-
// lookup peer
var peer *Peer
@@ -85,10 +85,12 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
+ device.log.Debug.Println("New IPv4 packet:", packet, dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
+ device.log.Debug.Println("New IPv6 packet:", packet, dst)
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
@@ -97,7 +99,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
- return
+ continue
}
// insert into nonce/pre-handshake queue
diff --git a/src/trie.go b/src/trie.go
index 4049167..c2304b2 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -195,7 +195,10 @@ func (node *Trie) Count() uint {
return l + r
}
-func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
+func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
+ if node == nil {
+ return results
+ }
if node.peer == p {
var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
@@ -213,6 +216,7 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
}
results = append(results, mask)
}
- node.child[0].AllowedIPs(p, results)
- node.child[1].AllowedIPs(p, results)
+ results = node.child[0].AllowedIPs(p, results)
+ results = node.child[1].AllowedIPs(p, results)
+ return results
}