aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-10-07 22:35:23 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-10-07 22:35:23 +0200
commit2d856045a0dbfc15d38d738e2a9d159ba2a49a47 (patch)
tree9548c5b8eb8de12bd669985a075d1d5545dd38e2 /src
parentDefinition of platform specific socket bind (diff)
downloadwireguard-go-2d856045a0dbfc15d38d738e2a9d159ba2a49a47.tar.xz
wireguard-go-2d856045a0dbfc15d38d738e2a9d159ba2a49a47.zip
Begin incorporating new src cache into receive
Diffstat (limited to '')
-rw-r--r--src/conn.go106
-rw-r--r--src/conn_linux.go70
-rw-r--r--src/device.go33
-rw-r--r--src/main.go1
-rw-r--r--src/receive.go53
5 files changed, 165 insertions, 98 deletions
diff --git a/src/conn.go b/src/conn.go
index 60cd789..61be3bf 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -3,7 +3,6 @@ package main
import (
"errors"
"net"
- "time"
)
func parseEndpoint(s string) (*net.UDPAddr, error) {
@@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
-func updateUDPConn(device *Device) error {
+func ListenerClose(l *Listener) (err error) {
+ if l.active {
+ err = CloseIPv4Socket(l.sock)
+ l.active = false
+ }
+ return
+}
+
+func (l *Listener) Init() {
+ l.update = make(chan struct{}, 1)
+ ListenerClose(l)
+}
+
+func ListeningUpdate(device *Device) error {
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
- // close existing connection
+ // close existing sockets
- if netc.conn != nil {
- netc.conn.Close()
- netc.conn = nil
+ if err := ListenerClose(&netc.ipv4); err != nil {
+ return err
+ }
- // We need for that fd to be closed in all other go routines, which
- // means we have to wait. TODO: find less horrible way of doing this.
- time.Sleep(time.Second / 2)
+ if err := ListenerClose(&netc.ipv6); err != nil {
+ return err
}
- // open new connection
+ // open new sockets
if device.tun.isUp.Get() {
- // listen on new address
-
- conn, err := net.ListenUDP("udp", netc.addr)
- if err != nil {
- return err
+ // listen on IPv4
+
+ {
+ list := &netc.ipv6
+ sock, port, err := CreateIPv4Socket(netc.port)
+ if err != nil {
+ return err
+ }
+ netc.port = port
+ list.sock = sock
+ list.active = true
+
+ if err := SetMark(list.sock, netc.fwmark); err != nil {
+ ListenerClose(list)
+ return err
+ }
+ signalSend(list.update)
}
- // set fwmark
-
- err = SetMark(netc.conn, netc.fwmark)
- if err != nil {
- return err
+ // listen on IPv6
+
+ {
+ list := &netc.ipv6
+ sock, port, err := CreateIPv6Socket(netc.port)
+ if err != nil {
+ return err
+ }
+ netc.port = port
+ list.sock = sock
+ list.active = true
+
+ if err := SetMark(list.sock, netc.fwmark); err != nil {
+ ListenerClose(list)
+ return err
+ }
+ signalSend(list.update)
}
- // retrieve port (may have been chosen by kernel)
-
- addr := conn.LocalAddr()
- netc.conn = conn
- netc.addr, _ = net.ResolveUDPAddr(
- addr.Network(),
- addr.String(),
- )
-
- // notify goroutines
-
- signalSend(device.signal.newUDPConn)
+ // TODO: clear endpoint caches
}
return nil
}
-func closeUDPConn(device *Device) {
+func ListeningClose(device *Device) error {
netc := &device.net
netc.mutex.Lock()
- if netc.conn != nil {
- netc.conn.Close()
+ defer netc.mutex.Unlock()
+
+ if err := ListenerClose(&netc.ipv4); err != nil {
+ return err
}
- netc.mutex.Unlock()
- signalSend(device.signal.newUDPConn)
+ signalSend(netc.ipv4.update)
+
+ if err := ListenerClose(&netc.ipv6); err != nil {
+ return err
+ }
+ signalSend(netc.ipv6.update)
+
+ return nil
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index 64447a5..034fb8b 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -28,6 +28,7 @@ import "fmt"
type Endpoint struct {
// source (selected based on dst type)
// (could use RawSockaddrAny and unsafe)
+ // TODO: Merge
src6 unix.RawSockaddrInet6
src4 unix.RawSockaddrInet4
src4if int32
@@ -35,8 +36,14 @@ type Endpoint struct {
dst unix.RawSockaddrAny
}
-type IPv4Socket int
-type IPv6Socket int
+type Socket int
+
+/* Returns a byte representation of the source field(s)
+ * for use in "under load" cookie computations.
+ */
+func (endpoint *Endpoint) Source() []byte {
+ return nil
+}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
@@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
-func CreateIPv4Socket(port int) (IPv4Socket, error) {
+func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
// create socket
@@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
)
if err != nil {
- return -1, err
+ return -1, 0, err
+ }
+
+ addr := unix.SockaddrInet4{
+ Port: int(port),
}
// set sockopts and bind
if err := func() error {
-
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
@@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
return err
}
- addr := unix.SockaddrInet4{
- Port: port,
- }
return unix.Bind(fd, &addr)
-
}(); err != nil {
unix.Close(fd)
}
- return IPv4Socket(fd), err
+ return Socket(fd), uint16(addr.Port), err
}
-func CreateIPv6Socket(port int) (IPv6Socket, error) {
+func CloseIPv4Socket(sock Socket) error {
+ return unix.Close(int(sock))
+}
+
+func CloseIPv6Socket(sock Socket) error {
+ return unix.Close(int(sock))
+}
+
+func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
// create socket
@@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
)
if err != nil {
- return -1, err
+ return -1, 0, err
}
// set sockopts and bind
+ addr := unix.SockaddrInet6{
+ Port: int(port),
+ }
+
if err := func() error {
if err := unix.SetsockoptInt(
@@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
return err
}
- addr := unix.SockaddrInet6{
- Port: port,
- }
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
- return IPv6Socket(fd), err
+ return Socket(fd), uint16(addr.Port), err
}
func (end *Endpoint) ClearSrc() {
@@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
return errors.New("Unknown address family of source")
}
-func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
+func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
// contruct message header
@@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
return int(size), nil
}
-func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
+func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
// contruct message header
@@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
// recvmsg(sock, &mskhdr, 0)
- _, _, errno := unix.Syscall(
+ size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
@@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
)
if errno != 0 {
- return errno
+ return 0, errno
}
// update source cache
@@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
end.src6.Scope_id = cmsg.pktinfo.Ifindex
}
- return nil
+ return int(size), nil
}
-func SetMark(conn *net.UDPConn, value uint32) error {
- if conn == nil {
- return nil
- }
-
- file, err := conn.File()
- if err != nil {
- return err
- }
-
+func SetMark(sock Socket, value uint32) error {
return unix.SetsockoptInt(
- int(file.Fd()),
+ int(sock),
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
diff --git a/src/device.go b/src/device.go
index 61c87bc..509e6a7 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,13 +1,18 @@
package main
import (
- "net"
"runtime"
"sync"
"sync/atomic"
"time"
)
+type Listener struct {
+ sock Socket
+ active bool
+ update chan struct{}
+}
+
type Device struct {
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
@@ -22,8 +27,9 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
- addr *net.UDPAddr // UDP source address
- conn *net.UDPConn // UDP "connection"
+ ipv4 Listener
+ ipv6 Listener
+ port uint16
fwmark uint32
}
mutex sync.RWMutex
@@ -37,8 +43,9 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{} // halts all go routines
- newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
+ stop chan struct{} // halts all go routines
+ updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
+ updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
+
device.indices.Init()
+ device.net.ipv4.Init()
+ device.net.ipv6.Init()
device.ratelimiter.Init()
+
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
- // setup pools
+ // setup buffer pool
device.pool.messageBuffers = sync.Pool{
New: func() interface{} {
@@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals
device.signal.stop = make(chan struct{})
- device.signal.newUDPConn = make(chan struct{}, 1)
// start workers
@@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
-
+ go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
- go device.RoutineReadFromTUN()
- go device.RoutineReceiveIncomming()
-
+ go device.RoutineReceiveIncomming(&device.net.ipv4)
+ go device.RoutineReceiveIncomming(&device.net.ipv6)
return device
}
@@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.signal.stop)
- closeUDPConn(device)
+ ListeningClose(device)
}
func (device *Device) WaitChannel() chan struct{} {
diff --git a/src/main.go b/src/main.go
index 196a4c6..a05dbba 100644
--- a/src/main.go
+++ b/src/main.go
@@ -14,6 +14,7 @@ func printUsage() {
}
func main() {
+ test()
// parse arguments
diff --git a/src/receive.go b/src/receive.go
index 52c2718..60c0f2c 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -13,10 +13,10 @@ import (
)
type QueueHandshakeElement struct {
- msgType uint32
- packet []byte
- buffer *[MaxMessageSize]byte
- source *net.UDPAddr
+ msgType uint32
+ packet []byte
+ endpoint Endpoint
+ buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
@@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue(
}
}
-func (device *Device) RoutineReceiveIncomming() {
+func (device *Device) RoutineReceiveIncomming(IPVersion int) {
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
+ var listener *Listener
+
+ switch IPVersion {
+ case ipv4.Version:
+ listener = &device.net.ipv4
+ case ipv6.Version:
+ listener = &device.net.ipv6
+ default:
+ return
+ }
+
for {
// wait for new conn
@@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() {
case <-device.signal.stop:
return
- case <-device.signal.newUDPConn:
+ case <-listener.update:
- // fetch connection
+ // fetch new socket
device.net.mutex.RLock()
- conn := device.net.conn
+ sock := listener.sock
+ okay := listener.active
device.net.mutex.RUnlock()
- if conn == nil {
+ if !okay {
continue
}
@@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() {
buffer := device.GetMessageBuffer()
+ var size int
+ var err error
+
for {
// read next datagram
- size, raddr, err := conn.ReadFromUDP(buffer[:])
+ var endpoint Endpoint
+
+ if IPVersion == ipv6.Version {
+ size, err = endpoint.ReceiveIPv4(sock, buffer[:])
+ } else {
+ size, err = endpoint.ReceiveIPv6(sock, buffer[:])
+ }
if err != nil {
break
@@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() {
buffer = device.GetMessageBuffer()
continue
- // otherwise it is a handshake related packet
+ // otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
@@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- source: raddr,
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ endpoint: endpoint,
},
)
buffer = device.GetMessageBuffer()
@@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet
- logDebug.Println("Process cookie reply from:", elem.source.String())
-
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)