aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--README.md8
-rw-r--r--conn/boundif_windows.go (renamed from device/boundif_windows.go)21
-rw-r--r--conn/conn.go101
-rw-r--r--conn/conn_default.go (renamed from device/conn_default.go)15
-rw-r--r--conn/conn_linux.go (renamed from device/conn_linux.go)264
-rw-r--r--conn/mark_default.go (renamed from device/mark_default.go)4
-rw-r--r--conn/mark_unix.go (renamed from device/mark_unix.go)4
-rw-r--r--device/allowedips.go2
-rw-r--r--device/allowedips_rand_test.go2
-rw-r--r--device/allowedips_test.go2
-rw-r--r--device/bind_test.go16
-rw-r--r--device/bindsocketshim.go36
-rw-r--r--device/boundif_android.go2
-rw-r--r--device/conn.go187
-rw-r--r--device/constants.go6
-rw-r--r--device/cookie.go2
-rw-r--r--device/cookie_test.go2
-rw-r--r--device/device.go166
-rw-r--r--device/device_test.go148
-rw-r--r--device/endpoint_test.go2
-rw-r--r--device/indextable.go2
-rw-r--r--device/ip.go2
-rw-r--r--device/kdf_test.go2
-rw-r--r--device/keypair.go12
-rw-r--r--device/logger.go2
-rw-r--r--device/misc.go2
-rw-r--r--device/noise-helpers.go2
-rw-r--r--device/noise-protocol.go191
-rw-r--r--device/noise-types.go11
-rw-r--r--device/noise_test.go4
-rw-r--r--device/peer.go29
-rw-r--r--device/peer_test.go43
-rw-r--r--device/pools.go2
-rw-r--r--device/queueconstants_android.go2
-rw-r--r--device/queueconstants_default.go2
-rw-r--r--device/queueconstants_ios.go2
-rw-r--r--device/receive.go11
-rw-r--r--device/send.go26
-rw-r--r--device/sticky_default.go12
-rw-r--r--device/sticky_linux.go215
-rw-r--r--device/timers.go2
-rw-r--r--device/tun.go2
-rw-r--r--device/tun_test.go2
-rw-r--r--device/uapi.go26
-rw-r--r--device/version.go2
-rw-r--r--go.mod8
-rw-r--r--go.sum13
-rw-r--r--ipc/uapi_bsd.go74
-rw-r--r--ipc/uapi_linux.go75
-rw-r--r--ipc/uapi_unix.go63
-rw-r--r--ipc/uapi_windows.go2
-rw-r--r--ipc/winpipe/file.go2
-rw-r--r--ipc/winpipe/mksyscall.go2
-rw-r--r--ipc/winpipe/pipe.go2
-rw-r--r--main.go24
-rw-r--r--main_windows.go2
-rw-r--r--ratelimiter/ratelimiter.go101
-rw-r--r--ratelimiter/ratelimiter_test.go58
-rw-r--r--replay/replay.go4
-rw-r--r--replay/replay_test.go2
-rw-r--r--rwcancel/fdset.go4
-rw-r--r--rwcancel/rwcancel.go6
-rw-r--r--rwcancel/rwcancel_windows.go8
-rw-r--r--rwcancel/select_default.go7
-rw-r--r--rwcancel/select_linux.go2
-rw-r--r--tai64n/tai64n.go13
-rw-r--r--tai64n/tai64n_test.go42
-rw-r--r--tun/operateonfd.go2
-rw-r--r--tun/tun.go2
-rw-r--r--tun/tun_darwin.go21
-rw-r--r--tun/tun_freebsd.go2
-rw-r--r--tun/tun_linux.go73
-rw-r--r--tun/tun_openbsd.go2
-rw-r--r--tun/tun_windows.go13
-rw-r--r--tun/tuntest/tuntest.go150
-rw-r--r--tun/wintun/iphlpapi/conversion_windows.go2
-rw-r--r--tun/wintun/iphlpapi/mksyscall.go2
-rw-r--r--tun/wintun/namespace_windows.go7
-rw-r--r--tun/wintun/namespaceapi/mksyscall.go2
-rw-r--r--tun/wintun/namespaceapi/namespaceapi_windows.go2
-rw-r--r--tun/wintun/nci/mksyscall.go2
-rw-r--r--tun/wintun/nci/nci_windows.go2
-rw-r--r--tun/wintun/registry/mksyscall.go2
-rw-r--r--tun/wintun/registry/registry_windows.go2
-rw-r--r--tun/wintun/registry/registry_windows_test.go2
-rw-r--r--tun/wintun/ring_windows.go30
-rw-r--r--tun/wintun/setupapi/mksyscall.go2
-rw-r--r--tun/wintun/setupapi/setupapi_windows.go2
-rw-r--r--tun/wintun/setupapi/setupapi_windows_test.go2
-rw-r--r--tun/wintun/setupapi/types32_windows.go (renamed from tun/wintun/setupapi/types_windows_386.go)4
-rw-r--r--tun/wintun/setupapi/types64_windows.go (renamed from tun/wintun/setupapi/types_windows_amd64.go)4
-rw-r--r--tun/wintun/setupapi/types_windows.go4
-rw-r--r--tun/wintun/setupapi/zsetupapi_windows_test.go2
-rw-r--r--tun/wintun/wintun_windows.go23
94 files changed, 1382 insertions, 1093 deletions
diff --git a/README.md b/README.md
index d73bf59..ea3d7cb 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0
```
-When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
+When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
-This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
+This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
### macOS
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building
-This requires an installation of [go](https://golang.org) ≥ 1.12.
+This requires an installation of [go](https://golang.org) ≥ 1.13.
```
$ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License
- Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
diff --git a/device/boundif_windows.go b/conn/boundif_windows.go
index 6908415..53a8f09 100644
--- a/device/boundif_windows.go
+++ b/conn/boundif_windows.go
@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
import (
"encoding/binary"
- "errors"
"unsafe"
"golang.org/x/sys/windows"
@@ -18,17 +17,13 @@ const (
sockoptIPV6_UNICAST_IF = 31
)
-func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
- if device.net.bind == nil {
- return errors.New("Bind is not yet initialized")
- }
-
- sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+ sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
if err != nil {
return err
}
- device.net.bind.(*nativeBind).blackhole4 = blackhole
+ bind.blackhole4 = blackhole
return nil
}
-func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
+func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
if err != nil {
return err
}
- device.net.bind.(*nativeBind).blackhole6 = blackhole
+ bind.blackhole6 = blackhole
return nil
}
diff --git a/conn/conn.go b/conn/conn.go
new file mode 100644
index 0000000..16311e4
--- /dev/null
+++ b/conn/conn.go
@@ -0,0 +1,101 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package conn implements WireGuard's network connections.
+package conn
+
+import (
+ "errors"
+ "net"
+ "strings"
+)
+
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
+type Bind interface {
+ // LastMark reports the last mark set for this Bind.
+ LastMark() uint32
+
+ // SetMark sets the mark for each packet sent through this Bind.
+ // This mark is passed to the kernel as the socket option SO_MARK.
+ SetMark(mark uint32) error
+
+ // ReceiveIPv6 reads an IPv6 UDP packet into b.
+ //
+ // It reports the number of bytes read, n,
+ // the packet source address ep,
+ // and any error.
+ ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
+
+ // ReceiveIPv4 reads an IPv4 UDP packet into b.
+ //
+ // It reports the number of bytes read, n,
+ // the packet source address ep,
+ // and any error.
+ ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
+
+ // Send writes a packet b to address ep.
+ Send(b []byte, ep Endpoint) error
+
+ // Close closes the Bind connection.
+ Close() error
+}
+
+// CreateBind creates a Bind bound to a port.
+//
+// The value actualPort reports the actual port number the Bind
+// object gets bound to.
+func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
+ return createBind(port)
+}
+
+// BindToInterface is implemented by Bind objects that support being
+// tied to a single network interface.
+type BindToInterface interface {
+ BindToInterface4(interfaceIndex uint32, blackhole bool) error
+ BindToInterface6(interfaceIndex uint32, blackhole bool) error
+}
+
+// An Endpoint maintains the source/destination caching for a peer.
+//
+// dst : the remote address of a peer ("endpoint" in uapi terminology)
+// src : the local address from which datagrams originate going to the peer
+type Endpoint interface {
+ ClearSrc() // clears the source address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() net.IP
+ SrcIP() net.IP
+}
+
+func parseEndpoint(s string) (*net.UDPAddr, error) {
+ // ensure that the host is an IP address
+
+ host, _, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
+ // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
+ // trying to make sure with a small sanity test that this is a real IP address and
+ // not something that's likely to incur DNS lookups.
+ host = host[:i]
+ }
+ if ip := net.ParseIP(host); ip == nil {
+ return nil, errors.New("Failed to parse IP address: " + host)
+ }
+
+ // parse address and port
+
+ addr, err := net.ResolveUDPAddr("udp", s)
+ if err != nil {
+ return nil, err
+ }
+ ip4 := addr.IP.To4()
+ if ip4 != nil {
+ addr.IP = ip4
+ }
+ return addr, err
+}
diff --git a/device/conn_default.go b/conn/conn_default.go
index 661f57d..96ef137 100644
--- a/device/conn_default.go
+++ b/conn/conn_default.go
@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
import (
"net"
@@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string {
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
-
- // listen
-
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
- // retrieve port
-
+ // Retrieve port.
+ // TODO(crawshaw): under what circumstances is this necessary?
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
@@ -100,7 +97,7 @@ func extractErrno(err error) error {
return syscallErr.Err
}
-func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
+func createBind(uport uint16) (Bind, uint16, error) {
var err error
var bind nativeBind
@@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error {
return err2
}
+func (bind *nativeBind) LastMark() uint32 { return 0 }
+
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
diff --git a/device/conn_linux.go b/conn/conn_linux.go
index f74ad51..08c8949 100644
--- a/device/conn_linux.go
+++ b/conn/conn_linux.go
@@ -2,19 +2,10 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- *
- * This implements userspace semantics of "sticky sockets", modeled after
- * WireGuard's kernelspace implementation. This is more or less a straight port
- * of the sticky-sockets.c example code:
- * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
- *
- * Currently there is no way to achieve this within the net package:
- * See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
import (
"errors"
@@ -25,7 +16,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/rwcancel"
)
const (
@@ -33,8 +23,8 @@ const (
)
type IPv4Source struct {
- src [4]byte
- ifindex int32
+ Src [4]byte
+ Ifindex int32
}
type IPv6Source struct {
@@ -43,11 +33,16 @@ type IPv6Source struct {
}
type NativeEndpoint struct {
+ sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
}
+func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
+func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
+func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
+
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
@@ -65,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
}
type nativeBind struct {
- sock4 int
- sock6 int
- netlinkSock int
- netlinkCancel *rwcancel.RWCancel
- lastMark uint32
+ sock4 int
+ sock6 int
+ lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
@@ -110,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
return nil, errors.New("Invalid IP address")
}
-func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
- if err != nil {
- return -1, err
- }
- saddr := &unix.SockaddrNetlink{
- Family: unix.AF_NETLINK,
- Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
- }
- err = unix.Bind(sock, saddr)
- if err != nil {
- unix.Close(sock)
- return -1, err
- }
- return sock, nil
-
-}
-
-func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
+func createBind(port uint16) (Bind, uint16, error) {
var err error
var bind nativeBind
var newPort uint16
- bind.netlinkSock, err = createNetlinkRouteSocket()
- if err != nil {
- return nil, 0, err
- }
- bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
- if err != nil {
- unix.Close(bind.netlinkSock)
- return nil, 0, err
- }
-
- go bind.routineRouteListener(device)
-
- // attempt ipv6 bind, update port if succesful
-
+ // Attempt ipv6 bind, update port if successful.
bind.sock6, newPort, err = create6(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
- bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
}
- // attempt ipv4 bind, update port if succesful
-
+ // Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
- bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
@@ -177,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
return &bind, port, nil
}
+func (bind *nativeBind) LastMark() uint32 {
+ return bind.lastMark
+}
+
func (bind *nativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
@@ -215,22 +178,18 @@ func closeUnblock(fd int) error {
}
func (bind *nativeBind) Close() error {
- var err1, err2, err3 error
+ var err1, err2 error
if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
- err3 = bind.netlinkCancel.Cancel()
if err1 != nil {
return err1
}
- if err2 != nil {
- return err2
- }
- return err3
+ return err2
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
@@ -277,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
return net.IPv4(
- end.src4().src[0],
- end.src4().src[1],
- end.src4().src[2],
- end.src4().src[3],
+ end.src4().Src[0],
+ end.src4().Src[1],
+ end.src4().Src[2],
+ end.src4().Src[3],
)
} else {
return end.src6().src[:]
@@ -477,12 +436,14 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
- Spec_dst: end.src4().src,
- Ifindex: end.src4().ifindex,
+ Spec_dst: end.src4().Src,
+ Ifindex: end.src4().Ifindex,
},
}
+ end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+ end.Unlock()
if err == nil {
return nil
@@ -493,7 +454,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
+ end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+ end.Unlock()
}
return err
@@ -522,7 +485,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0
}
+ end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+ end.Unlock()
if err == nil {
return nil
@@ -533,7 +498,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
+ end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+ end.Unlock()
}
return err
@@ -541,7 +508,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- // contruct message header
+ // construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
@@ -564,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().src = cmsg.pktinfo.Spec_dst
- end.src4().ifindex = cmsg.pktinfo.Ifindex
+ end.src4().Src = cmsg.pktinfo.Spec_dst
+ end.src4().Ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
@@ -573,7 +540,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- // contruct message header
+ // construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
@@ -602,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil
}
-
-func (bind *nativeBind) routineRouteListener(device *Device) {
- type peerEndpointPtr struct {
- peer *Peer
- endpoint *Endpoint
- }
- var reqPeer map[uint32]peerEndpointPtr
- var reqPeerLock sync.Mutex
-
- defer unix.Close(bind.netlinkSock)
-
- for msg := make([]byte, 1<<16); ; {
- var err error
- var msgn int
- for {
- msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
- if err == nil || !rwcancel.RetryAfterError(err) {
- break
- }
- if !bind.netlinkCancel.ReadyRead() {
- return
- }
- }
- if err != nil {
- return
- }
-
- for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
-
- hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
-
- if uint(hdr.Len) > uint(len(remain)) {
- break
- }
-
- switch hdr.Type {
- case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
- if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
- if uint(len(remain)) < uint(hdr.Len) {
- break
- }
- if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
- attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
- for {
- if uint(len(attr)) < uint(unix.SizeofRtAttr) {
- break
- }
- attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
- if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
- break
- }
- if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
- ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
- reqPeerLock.Lock()
- if reqPeer == nil {
- reqPeerLock.Unlock()
- break
- }
- pePtr, ok := reqPeer[hdr.Seq]
- reqPeerLock.Unlock()
- if !ok {
- break
- }
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
- break
- }
- if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
- pePtr.peer.Unlock()
- break
- }
- pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
- pePtr.peer.Unlock()
- }
- attr = attr[attrhdr.Len:]
- }
- }
- break
- }
- reqPeerLock.Lock()
- reqPeer = make(map[uint32]peerEndpointPtr)
- reqPeerLock.Unlock()
- go func() {
- device.peers.RLock()
- i := uint32(1)
- for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
- peer.RUnlock()
- continue
- }
- if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
- peer.RUnlock()
- break
- }
- nlmsg := struct {
- hdr unix.NlMsghdr
- msg unix.RtMsg
- dsthdr unix.RtAttr
- dst [4]byte
- srchdr unix.RtAttr
- src [4]byte
- markhdr unix.RtAttr
- mark uint32
- }{
- unix.NlMsghdr{
- Type: uint16(unix.RTM_GETROUTE),
- Flags: unix.NLM_F_REQUEST,
- Seq: i,
- },
- unix.RtMsg{
- Family: unix.AF_INET,
- Dst_len: 32,
- Src_len: 32,
- },
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_DST,
- },
- peer.endpoint.(*NativeEndpoint).dst4().Addr,
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_SRC,
- },
- peer.endpoint.(*NativeEndpoint).src4().src,
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_MARK,
- },
- uint32(bind.lastMark),
- }
- nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
- reqPeerLock.Lock()
- reqPeer[i] = peerEndpointPtr{
- peer: peer,
- endpoint: &peer.endpoint,
- }
- reqPeerLock.Unlock()
- peer.RUnlock()
- i++
- _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
- if err != nil {
- break
- }
- }
- device.peers.RUnlock()
- }()
- }
- remain = remain[hdr.Len:]
- }
- }
-}
diff --git a/device/mark_default.go b/conn/mark_default.go
index 7de2524..f57215a 100644
--- a/device/mark_default.go
+++ b/conn/mark_default.go
@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
func (bind *nativeBind) SetMark(mark uint32) error {
return nil
diff --git a/device/mark_unix.go b/conn/mark_unix.go
index 669b328..19ec2af 100644
--- a/device/mark_unix.go
+++ b/conn/mark_unix.go
@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
import (
"runtime"
diff --git a/device/allowedips.go b/device/allowedips.go
index efc27c0..143bda3 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 59c10f7..3947830 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 075ff06..005df48 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/bind_test.go b/device/bind_test.go
index 0c2e2cf..339adbe 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,15 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
-import "errors"
+import (
+ "errors"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
type DummyDatagram struct {
msg []byte
- endpoint Endpoint
+ endpoint conn.Endpoint
world bool // better type
}
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end Endpoint) error {
+func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go
new file mode 100644
index 0000000..c2555b8
--- /dev/null
+++ b/device/bindsocketshim.go
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "errors"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
+func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ if device.net.bind == nil {
+ return errors.New("Bind is not yet initialized")
+ }
+
+ if iface, ok := device.net.bind.(conn.BindToInterface); ok {
+ return iface.BindToInterface4(interfaceIndex, blackhole)
+ }
+ return nil
+}
+
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
+func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ if device.net.bind == nil {
+ return errors.New("Bind is not yet initialized")
+ }
+
+ if iface, ok := device.net.bind.(conn.BindToInterface); ok {
+ return iface.BindToInterface6(interfaceIndex, blackhole)
+ }
+ return nil
+}
diff --git a/device/boundif_android.go b/device/boundif_android.go
index 6d0fecf..a4be8de 100644
--- a/device/boundif_android.go
+++ b/device/boundif_android.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/conn.go b/device/conn.go
deleted file mode 100644
index 7b341f6..0000000
--- a/device/conn.go
+++ /dev/null
@@ -1,187 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "errors"
- "net"
- "strings"
-
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
-)
-
-const (
- ConnRoutineNumber = 2
-)
-
-/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
- */
-type Bind interface {
- SetMark(value uint32) error
- ReceiveIPv6(buff []byte) (int, Endpoint, error)
- ReceiveIPv4(buff []byte) (int, Endpoint, error)
- Send(buff []byte, end Endpoint) error
- Close() error
-}
-
-/* An Endpoint maintains the source/destination caching for a peer
- *
- * dst : the remote address of a peer ("endpoint" in uapi terminology)
- * src : the local address from which datagrams originate going to the peer
- */
-type Endpoint interface {
- ClearSrc() // clears the source address
- SrcToString() string // returns the local source address (ip:port)
- DstToString() string // returns the destination address (ip:port)
- DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
-}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}
-
-func unsafeCloseBind(device *Device) error {
- var err error
- netc := &device.net
- if netc.bind != nil {
- err = netc.bind.Close()
- netc.bind = nil
- }
- netc.stopping.Wait()
- return err
-}
-
-func (device *Device) BindSetMark(mark uint32) error {
-
- device.net.Lock()
- defer device.net.Unlock()
-
- // check if modified
-
- if device.net.fwmark == mark {
- return nil
- }
-
- // update fwmark on existing bind
-
- device.net.fwmark = mark
- if device.isUp.Get() && device.net.bind != nil {
- if err := device.net.bind.SetMark(mark); err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- }
- device.peers.RUnlock()
-
- return nil
-}
-
-func (device *Device) BindUpdate() error {
-
- device.net.Lock()
- defer device.net.Unlock()
-
- // close existing sockets
-
- if err := unsafeCloseBind(device); err != nil {
- return err
- }
-
- // open new sockets
-
- if device.isUp.Get() {
-
- // bind to new port
-
- var err error
- netc := &device.net
- netc.bind, netc.port, err = CreateBind(netc.port, device)
- if err != nil {
- netc.bind = nil
- netc.port = 0
- return err
- }
-
- // set fwmark
-
- if netc.fwmark != 0 {
- err = netc.bind.SetMark(netc.fwmark)
- if err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- }
- device.peers.RUnlock()
-
- // start receiving routines
-
- device.net.starting.Add(ConnRoutineNumber)
- device.net.stopping.Add(ConnRoutineNumber)
- go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
- go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
- device.net.starting.Wait()
-
- device.log.Debug.Println("UDP bind has been updated")
- }
-
- return nil
-}
-
-func (device *Device) BindClose() error {
- device.net.Lock()
- err := unsafeCloseBind(device)
- device.net.Unlock()
- return err
-}
diff --git a/device/constants.go b/device/constants.go
index e316f32..1a4b8ea 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,8 +12,8 @@ import (
/* Specification constants */
const (
- RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
- RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ RekeyAfterMessages = (1 << 60)
+ RejectAfterMessages = (1 << 64) - (1 << 13) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
diff --git a/device/cookie.go b/device/cookie.go
index f134128..c658ca3 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 79a6a86..7e4c362 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/device.go b/device/device.go
index 569c5a8..11119f9 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,15 +11,14 @@ import (
"sync/atomic"
"time"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
+ "golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
-const (
- DeviceRoutineNumberPerCPU = 3
- DeviceRoutineNumberAdditional = 2
-)
-
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
@@ -39,9 +38,10 @@ type Device struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
- bind Bind // bind interface
- port uint16 // listening port
- fwmark uint32 // mark value (0 = disabled)
+ bind conn.Bind // bind interface
+ netlinkCancel *rwcancel.RWCancel
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
@@ -236,23 +236,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations
- rmKey := device.staticIdentity.privateKey.IsZero()
-
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
- for key, peer := range device.peers.keyMap {
+ for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
-
- if rmKey {
- handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
- } else {
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
- }
-
- if isZero(handshake.precomputedStaticStatic[:]) {
- unsafeRemovePeer(device, peer, key)
- } else {
- expiredPeers = append(expiredPeers, peer)
- }
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ expiredPeers = append(expiredPeers, peer)
}
for _, peer := range lockedPeers {
@@ -311,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait()
- device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
- device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 {
+ device.state.starting.Add(3)
+ device.state.stopping.Add(3)
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
}
+ device.state.starting.Add(2)
+ device.state.stopping.Add(2)
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
@@ -425,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
}
device.peers.RUnlock()
}
+
+func unsafeCloseBind(device *Device) error {
+ var err error
+ netc := &device.net
+ if netc.netlinkCancel != nil {
+ netc.netlinkCancel.Cancel()
+ }
+ if netc.bind != nil {
+ err = netc.bind.Close()
+ netc.bind = nil
+ }
+ netc.stopping.Wait()
+ return err
+}
+
+func (device *Device) BindSetMark(mark uint32) error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // check if modified
+
+ if device.net.fwmark == mark {
+ return nil
+ }
+
+ // update fwmark on existing bind
+
+ device.net.fwmark = mark
+ if device.isUp.Get() && device.net.bind != nil {
+ if err := device.net.bind.SetMark(mark); err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ return nil
+}
+
+func (device *Device) BindUpdate() error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // close existing sockets
+
+ if err := unsafeCloseBind(device); err != nil {
+ return err
+ }
+
+ // open new sockets
+
+ if device.isUp.Get() {
+
+ // bind to new port
+
+ var err error
+ netc := &device.net
+ netc.bind, netc.port, err = conn.CreateBind(netc.port)
+ if err != nil {
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+ netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ if err != nil {
+ netc.bind.Close()
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+
+ // set fwmark
+
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
+ if err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ // start receiving routines
+
+ device.net.starting.Add(2)
+ device.net.stopping.Add(2)
+ go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
+ go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+ device.net.starting.Wait()
+
+ device.log.Debug.Println("UDP bind has been updated")
+ }
+
+ return nil
+}
+
+func (device *Device) BindClose() error {
+ device.net.Lock()
+ err := unsafeCloseBind(device)
+ device.net.Unlock()
+ return err
+}
diff --git a/device/device_test.go b/device/device_test.go
index 14cc605..5ea5410 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,15 +8,12 @@ package device
import (
"bufio"
"bytes"
- "encoding/binary"
- "io"
"net"
- "os"
"strings"
"testing"
"time"
- "golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
func TestTwoDevicePing(t *testing.T) {
@@ -29,7 +26,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
- tun1 := NewChannelTUN()
+ tun1 := tuntest.NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
@@ -45,7 +42,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
- tun2 := NewChannelTUN()
+ tun2 := tuntest.NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
@@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511`
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
+ msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
@@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511`
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
+ msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
@@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511`
})
}
-func ping(dst, src net.IP) []byte {
- localPort := uint16(1337)
- seq := uint16(0)
-
- payload := make([]byte, 4)
- binary.BigEndian.PutUint16(payload[0:], localPort)
- binary.BigEndian.PutUint16(payload[2:], seq)
-
- return genICMPv4(payload, dst, src)
-}
-
-// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
-func checksum(buf []byte, initial uint16) uint16 {
- v := uint32(initial)
- for i := 0; i < len(buf)-1; i += 2 {
- v += uint32(binary.BigEndian.Uint16(buf[i:]))
- }
- if len(buf)%2 == 1 {
- v += uint32(buf[len(buf)-1]) << 8
- }
- for v > 0xffff {
- v = (v >> 16) + (v & 0xffff)
- }
- return ^uint16(v)
-}
-
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
- const (
- icmpv4ProtocolNumber = 1
- icmpv4Echo = 8
- icmpv4ChecksumOffset = 2
- icmpv4Size = 8
- ipv4Size = 20
- ipv4TotalLenOffset = 2
- ipv4ChecksumOffset = 10
- ttl = 65
- )
-
- hdr := make([]byte, ipv4Size+icmpv4Size)
-
- ip := hdr[0:ipv4Size]
- icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
-
- // https://tools.ietf.org/html/rfc792
- icmpv4[0] = icmpv4Echo // type
- icmpv4[1] = 0 // code
- chksum := ^checksum(icmpv4, checksum(payload, 0))
- binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
-
- // https://tools.ietf.org/html/rfc760 section 3.1
- length := uint16(len(hdr) + len(payload))
- ip[0] = (4 << 4) | (ipv4Size / 4)
- binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
- ip[8] = ttl
- ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
- chksum = ^checksum(ip[:], 0)
- binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
-
- var v []byte
- v = append(v, hdr...)
- v = append(v, payload...)
- return []byte(v)
-}
-
-// TODO(crawshaw): find a reusable home for this. package devicetest?
-type ChannelTUN struct {
- Inbound chan []byte // incoming packets, closed on TUN close
- Outbound chan []byte // outbound packets, blocks forever on TUN close
-
- closed chan struct{}
- events chan tun.Event
- tun chTun
-}
-
-func NewChannelTUN() *ChannelTUN {
- c := &ChannelTUN{
- Inbound: make(chan []byte),
- Outbound: make(chan []byte),
- closed: make(chan struct{}),
- events: make(chan tun.Event, 1),
- }
- c.tun.c = c
- c.events <- tun.EventUp
- return c
-}
-
-func (c *ChannelTUN) TUN() tun.Device {
- return &c.tun
-}
-
-type chTun struct {
- c *ChannelTUN
-}
-
-func (t *chTun) File() *os.File { return nil }
-
-func (t *chTun) Read(data []byte, offset int) (int, error) {
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
- }
-}
-
-// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
- if offset == -1 {
- close(t.c.closed)
- close(t.c.events)
- return 0, io.EOF
- }
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
- }
-}
-
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
-func (t *chTun) Close() error {
- t.Write(nil, -1)
- return nil
-}
-
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 1896790..e66d493 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/indextable.go b/device/indextable.go
index 4cba970..5e10eef 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/ip.go b/device/ip.go
index 9d4fb74..3bc6929 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index cb8dbab..1a3bc87 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/keypair.go b/device/keypair.go
index 9c78fa9..2f2f222 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
+ "sync/atomic"
"time"
+ "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair
}
+func (kp *Keypairs) storeNext(next *Keypair) {
+ atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
+}
+
+func (kp *Keypairs) loadNext() *Keypair {
+ return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+}
+
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()
diff --git a/device/logger.go b/device/logger.go
index 7c8b704..3c4d744 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/misc.go b/device/misc.go
index a38d1c1..30d1156 100644
--- a/device/misc.go
+++ b/device/misc.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index f5e4b4b..b3b5acf 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 88c6aae..be92b4b 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,29 +1,51 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
+ "fmt"
"sync"
"time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
+type handshakeState int
+
+// TODO(crawshaw): add commentary describing each state and the transitions
const (
- HandshakeZeroed = iota
- HandshakeInitiationCreated
- HandshakeInitiationConsumed
- HandshakeResponseCreated
- HandshakeResponseConsumed
+ handshakeZeroed = handshakeState(iota)
+ handshakeInitiationCreated
+ handshakeInitiationConsumed
+ handshakeResponseCreated
+ handshakeResponseConsumed
)
+func (hs handshakeState) String() string {
+ switch hs {
+ case handshakeZeroed:
+ return "handshakeZeroed"
+ case handshakeInitiationCreated:
+ return "handshakeInitiationCreated"
+ case handshakeInitiationConsumed:
+ return "handshakeInitiationConsumed"
+ case handshakeResponseCreated:
+ return "handshakeResponseCreated"
+ case handshakeResponseConsumed:
+ return "handshakeResponseConsumed"
+ default:
+ return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+ }
+}
+
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -39,13 +61,13 @@ const (
)
const (
- MessageInitiationSize = 148 // size of handshake initation message
+ MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
- MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
+ MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
- MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
+ MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
const (
@@ -95,7 +117,7 @@ type MessageCookieReply struct {
}
type Handshake struct {
- state int
+ state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
@@ -135,7 +157,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
- h.state = HandshakeZeroed
+ h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
@@ -154,6 +176,7 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+ var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +185,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errors.New("static shared secret is zero")
- }
-
// create ephemeral key
-
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,59 +194,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
- // assign index
-
- device.indexTable.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
-
- if err != nil {
- return nil, err
- }
-
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
- Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
-
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- ss[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
- }()
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if isZero(ss[:]) {
+ return nil, errZeroECDHResult
+ }
+ var key [chacha20poly1305.KeySize]byte
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:])
// encrypt timestamp
-
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errZeroECDHResult
+ }
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
timestamp := tai64n.Now()
- func() {
- var key [chacha20poly1305.KeySize]byte
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
- }()
+ aead, _ = chacha20poly1305.New(key[:])
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+
+ // assign index
+ device.indexTable.Delete(handshake.localIndex)
+ msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+ if err != nil {
+ return nil, err
+ }
+ handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
- handshake.state = HandshakeInitiationCreated
+ handshake.state = handshakeInitiationCreated
return &msg, nil
}
@@ -250,16 +265,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
-
var err error
var peerPK NoisePublicKey
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- KDF2(&chainKey, &key, chainKey[:], ss[:])
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
- }()
+ var key [chacha20poly1305.KeySize]byte
+ ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if isZero(ss[:]) {
+ return nil
+ }
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil {
return nil
}
@@ -273,23 +288,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
handshake := &peer.handshake
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil
- }
// verify identity
var timestamp tai64n.Timestamp
- var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ handshake.mutex.RUnlock()
+ return nil
+ }
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
- aead, _ := chacha20poly1305.New(key[:])
+ aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -299,11 +315,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood
- var ok bool
- ok = timestamp.After(handshake.lastTimestamp)
- ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
+ replay := !timestamp.After(handshake.lastTimestamp)
+ flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
- if !ok {
+ if replay {
+ device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
+ return nil
+ }
+ if flood {
+ device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
return nil
}
@@ -322,7 +342,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
- handshake.state = HandshakeInitiationConsumed
+ handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
@@ -337,7 +357,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationConsumed {
+ if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
@@ -393,7 +413,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Empty[:])
}()
- handshake.state = HandshakeResponseCreated
+ handshake.state = handshakeResponseCreated
return &msg, nil
}
@@ -423,7 +443,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
- if handshake.state != HandshakeInitiationCreated {
+ if handshake.state != handshakeInitiationCreated {
return false
}
@@ -484,7 +504,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
- handshake.state = HandshakeResponseConsumed
+ handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
@@ -509,7 +529,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
- if handshake.state == HandshakeResponseConsumed {
+ if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
@@ -517,7 +537,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil,
)
isInitiator = true
- } else if handshake.state == HandshakeResponseCreated {
+ } else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
@@ -526,7 +546,7 @@ func (peer *Peer) BeginSymmetricSession() error {
)
isInitiator = false
} else {
- return errors.New("invalid state for keypair derivation")
+ return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
@@ -534,7 +554,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
- peer.handshake.state = HandshakeZeroed
+ peer.handshake.state = handshakeZeroed
// create AEAD instances
@@ -564,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -578,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -589,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.next != receivedKeypair {
+
+ if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.loadNext()
+ keypairs.storeNext(nil)
return true
}
diff --git a/device/noise-types.go b/device/noise-types.go
index 6b1f16f..f793ef5 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return
}
+func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
+ err = loadExactHex(key[:], src)
+ if key.IsZero() {
+ return
+ }
+ key.clamp()
+ return
+}
+
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}
diff --git a/device/noise_test.go b/device/noise_test.go
index 6ba3f2e..ce89851 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.next
+ key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index 91d975a..d13acd9 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,6 +12,8 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.zx2c4.com/wireguard/conn"
)
const (
@@ -24,10 +26,14 @@ type Peer struct {
keypairs Keypairs
handshake Handshake
device *Device
- endpoint Endpoint
+ endpoint conn.Endpoint
persistentKeepaliveInterval uint16
- // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
+ // These fields are accessed with atomic operations, which must be
+ // 64-bit aligned even on 32-bit platforms. Go guarantees that an
+ // allocated struct will be 64-bit aligned. So we place
+ // atomically-accessed fields up front, so that they can share in
+ // this alignment before smaller fields throw it off.
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@@ -108,7 +114,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
- ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock()
@@ -116,13 +121,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
- // conditionally add
+ // add
- if !ssIsZero {
- device.peers.keyMap[pk] = peer
- } else {
- return nil, nil
- }
+ device.peers.keyMap[pk] = peer
// start peer
@@ -222,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
@@ -253,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}
@@ -291,7 +292,7 @@ func (peer *Peer) Stop() {
var RoamingDisabled bool
-func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
+func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if RoamingDisabled {
return
}
diff --git a/device/peer_test.go b/device/peer_test.go
new file mode 100644
index 0000000..6aa238b
--- /dev/null
+++ b/device/peer_test.go
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+func checkAlignment(t *testing.T, name string, offset uintptr) {
+ t.Helper()
+ if offset%8 != 0 {
+ t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
+ }
+}
+
+// TestPeerAlignment checks that atomically-accessed fields are
+// aligned to 64-bit boundaries, as required by the atomic package.
+//
+// Unfortunately, violating this rule on 32-bit platforms results in a
+// hard segfault at runtime.
+func TestPeerAlignment(t *testing.T) {
+ var p Peer
+
+ typ := reflect.TypeOf(p)
+ t.Logf("Peer type size: %d, with fields:", typ.Size())
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+ t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
+ field.Name,
+ field.Offset,
+ field.Type.Size(),
+ field.Type.Align(),
+ )
+ }
+
+ checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
+ checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
+}
diff --git a/device/pools.go b/device/pools.go
index 98f4ef1..e778d2e 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index f5c042d..f19c7be 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index cf86ba1..18f0bea 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 589b0aa..4c83015 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/receive.go b/device/receive.go
index 7d0693e..b53c9c0 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,12 +17,13 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
- endpoint Endpoint
+ endpoint conn.Endpoint
buffer *[MaxMessageSize]byte
}
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
packet []byte
counter uint64
keypair *Keypair
- endpoint Endpoint
+ endpoint conn.Endpoint
}
func (elem *QueueInboundElement) Drop() {
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
+func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
logDebug := device.log.Debug
defer func() {
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
var (
err error
size int
- endpoint Endpoint
+ endpoint conn.Endpoint
)
for {
diff --git a/device/send.go b/device/send.go
index 72633be..c0bdba3 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -448,6 +448,21 @@ func (peer *Peer) RoutineNonce() {
}
}
+func calculatePaddingSize(packetSize, mtu int) int {
+ lastUnit := packetSize
+ if mtu == 0 {
+ return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
+ }
+ if lastUnit > mtu {
+ lastUnit %= mtu
+ }
+ paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
+ if paddedSize > mtu {
+ paddedSize = mtu
+ }
+ return paddedSize - lastUnit
+}
+
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
@@ -514,13 +529,8 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16
- mtu := int(atomic.LoadInt32(&device.tun.mtu))
- lastUnit := len(elem.packet) % mtu
- paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
- if paddedSize > mtu {
- paddedSize = mtu
- }
- for i := len(elem.packet); i < paddedSize; i++ {
+ paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
+ for i := 0; i < paddingSize; i++ {
elem.packet = append(elem.packet, 0)
}
diff --git a/device/sticky_default.go b/device/sticky_default.go
new file mode 100644
index 0000000..1cc52f6
--- /dev/null
+++ b/device/sticky_default.go
@@ -0,0 +1,12 @@
+// +build !linux
+
+package device
+
+import (
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ return nil, nil
+}
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
new file mode 100644
index 0000000..1994a70
--- /dev/null
+++ b/device/sticky_linux.go
@@ -0,0 +1,215 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ *
+ * This implements userspace semantics of "sticky sockets", modeled after
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
+ */
+
+package device
+
+import (
+ "sync"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ netlinkSock, err := createNetlinkRouteSocket()
+ if err != nil {
+ return nil, err
+ }
+ netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
+ if err != nil {
+ unix.Close(netlinkSock)
+ return nil, err
+ }
+
+ go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
+
+ return netlinkCancel, nil
+}
+
+func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
+ type peerEndpointPtr struct {
+ peer *Peer
+ endpoint *conn.Endpoint
+ }
+ var reqPeer map[uint32]peerEndpointPtr
+ var reqPeerLock sync.Mutex
+
+ defer unix.Close(netlinkSock)
+
+ for msg := make([]byte, 1<<16); ; {
+ var err error
+ var msgn int
+ for {
+ msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
+ if err == nil || !rwcancel.RetryAfterError(err) {
+ break
+ }
+ if !netlinkCancel.ReadyRead() {
+ return
+ }
+ }
+ if err != nil {
+ return
+ }
+
+ for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+ hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+ if uint(hdr.Len) > uint(len(remain)) {
+ break
+ }
+
+ switch hdr.Type {
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+ if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
+ if uint(len(remain)) < uint(hdr.Len) {
+ break
+ }
+ if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+ attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+ for {
+ if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+ break
+ }
+ attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+ if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+ break
+ }
+ if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+ ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+ reqPeerLock.Lock()
+ if reqPeer == nil {
+ reqPeerLock.Unlock()
+ break
+ }
+ pePtr, ok := reqPeer[hdr.Seq]
+ reqPeerLock.Unlock()
+ if !ok {
+ break
+ }
+ pePtr.peer.Lock()
+ if &pePtr.peer.endpoint != pePtr.endpoint {
+ pePtr.peer.Unlock()
+ break
+ }
+ if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
+ pePtr.peer.Unlock()
+ break
+ }
+ pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
+ pePtr.peer.Unlock()
+ }
+ attr = attr[attrhdr.Len:]
+ }
+ }
+ break
+ }
+ reqPeerLock.Lock()
+ reqPeer = make(map[uint32]peerEndpointPtr)
+ reqPeerLock.Unlock()
+ go func() {
+ device.peers.RLock()
+ i := uint32(1)
+ for _, peer := range device.peers.keyMap {
+ peer.RLock()
+ if peer.endpoint == nil {
+ peer.RUnlock()
+ continue
+ }
+ nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
+ if nativeEP == nil {
+ peer.RUnlock()
+ continue
+ }
+ if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
+ peer.RUnlock()
+ break
+ }
+ nlmsg := struct {
+ hdr unix.NlMsghdr
+ msg unix.RtMsg
+ dsthdr unix.RtAttr
+ dst [4]byte
+ srchdr unix.RtAttr
+ src [4]byte
+ markhdr unix.RtAttr
+ mark uint32
+ }{
+ unix.NlMsghdr{
+ Type: uint16(unix.RTM_GETROUTE),
+ Flags: unix.NLM_F_REQUEST,
+ Seq: i,
+ },
+ unix.RtMsg{
+ Family: unix.AF_INET,
+ Dst_len: 32,
+ Src_len: 32,
+ },
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_DST,
+ },
+ nativeEP.Dst4().Addr,
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_SRC,
+ },
+ nativeEP.Src4().Src,
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_MARK,
+ },
+ uint32(bind.LastMark()),
+ }
+ nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+ reqPeerLock.Lock()
+ reqPeer[i] = peerEndpointPtr{
+ peer: peer,
+ endpoint: &peer.endpoint,
+ }
+ reqPeerLock.Unlock()
+ peer.RUnlock()
+ i++
+ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+ if err != nil {
+ break
+ }
+ }
+ device.peers.RUnlock()
+ }()
+ }
+ remain = remain[hdr.Len:]
+ }
+ }
+}
+
+func createNetlinkRouteSocket() (int, error) {
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ if err != nil {
+ return -1, err
+ }
+ saddr := &unix.SockaddrNetlink{
+ Family: unix.AF_NETLINK,
+ Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+ }
+ err = unix.Bind(sock, saddr)
+ if err != nil {
+ unix.Close(sock)
+ return -1, err
+ }
+ return sock, nil
+}
diff --git a/device/timers.go b/device/timers.go
index 18ee736..0232eef 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
diff --git a/device/tun.go b/device/tun.go
index 0a3fc79..1f88f33 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/tun_test.go b/device/tun_test.go
index 5614771..a2db2a5 100644
--- a/device/tun_test.go
+++ b/device/tun_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/uapi.go b/device/uapi.go
index 999eeb5..9f9c9bd 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,12 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
+ "errors"
"fmt"
"io"
"net"
@@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -30,7 +32,7 @@ func (s IPCError) ErrorCode() int64 {
return s.int64
}
-func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
+func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
@@ -105,7 +107,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
return nil
}
-func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
+func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
scanner := bufio.NewScanner(socket)
logError := device.log.Error
logDebug := device.log.Debug
@@ -138,7 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "private_key":
var sk NoisePrivateKey
- err := sk.FromHex(value)
+ err := sk.FromMaybeZeroHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
@@ -306,7 +308,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
err := func() error {
peer.Lock()
defer peer.Unlock()
- endpoint, err := CreateEndpoint(value)
+ endpoint, err := conn.CreateEndpoint(value)
if err != nil {
return err
}
@@ -420,10 +422,20 @@ func (device *Device) IpcHandle(socket net.Conn) {
switch op {
case "set=1\n":
- status = device.IpcSetOperation(buffered.Reader)
+ err = device.IpcSetOperation(buffered.Reader)
+ if err != nil && !errors.As(err, &status) {
+ // should never happen
+ device.log.Error.Println("Invalid UAPI error:", err)
+ status = &IPCError{1}
+ }
case "get=1\n":
- status = device.IpcGetOperation(buffered.Writer)
+ err = device.IpcGetOperation(buffered.Writer)
+ if err != nil && !errors.As(err, &status) {
+ // should never happen
+ device.log.Error.Println("Invalid UAPI error:", err)
+ status = &IPCError{1}
+ }
default:
device.log.Error.Println("Invalid UAPI operation:", op)
diff --git a/device/version.go b/device/version.go
index 326b9a9..0877595 100644
--- a/device/version.go
+++ b/device/version.go
@@ -1,3 +1,3 @@
package device
-const WireGuardGoVersion = "0.0.20191012"
+const WireGuardGoVersion = "0.0.20200320"
diff --git a/go.mod b/go.mod
index 34b1e72..f264ea0 100644
--- a/go.mod
+++ b/go.mod
@@ -1,10 +1,10 @@
module golang.zx2c4.com/wireguard
-go 1.12
+go 1.13
require (
- golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
- golang.org/x/net v0.0.0-20191003171128-d98b1b443823
- golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
+ golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79
+ golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5
+ golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3
golang.org/x/text v0.3.2
)
diff --git a/go.sum b/go.sum
index 970f4cb..e6b8991 100644
--- a/go.sum
+++ b/go.sum
@@ -1,13 +1,14 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
-golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79 h1:IaQbIIB2X/Mp/DKctl6ROxz1KyMlKp4uyvL6+kQ7C88=
+golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
-golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5 h1:WQ8q63x+f/zpC8Ac1s9wLElVoHhm32p6tudrU72n1QA=
+golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
-golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3 h1:5B6i6EAiSYyejWfvc5Rc9BbI3rzIsrrXfAQBWnYfn+w=
+golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go
index 75cc0e3..ee05cb7 100644
--- a/ipc/uapi_bsd.go
+++ b/ipc/uapi_bsd.go
@@ -2,32 +2,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
- "fmt"
"net"
"os"
- "path"
"unsafe"
"golang.org/x/sys/unix"
)
-var socketDirectory = "/var/run/wireguard"
-
-const (
- IpcErrorIO = -int64(unix.EIO)
- IpcErrorProtocol = -int64(unix.EPROTO)
- IpcErrorInvalid = -int64(unix.EINVAL)
- IpcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketName = "%s.sock"
-)
-
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -84,10 +72,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
unixListener.SetUnlinkOnClose(true)
}
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
+ socketPath := sockPath(name)
// watch for deletion of socket
@@ -146,58 +131,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
-
-func UAPIOpen(name string) (*os.File, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- addr, err := net.ResolveUnixAddr("unix", socketPath)
- if err != nil {
- return nil, err
- }
-
- oldUmask := unix.Umask(0077)
- listener, err := func() (*net.UnixListener, error) {
-
- // initial connection attempt
-
- listener, err := net.ListenUnix("unix", addr)
- if err == nil {
- return listener, nil
- }
-
- // check if socket already active
-
- _, err = net.Dial("unix", socketPath)
- if err == nil {
- return nil, errors.New("unix socket in use")
- }
-
- // cleanup & attempt again
-
- err = os.Remove(socketPath)
- if err != nil {
- return nil, err
- }
- return net.ListenUnix("unix", addr)
- }()
- unix.Umask(oldUmask)
-
- if err != nil {
- return nil, err
- }
-
- return listener.File()
-}
diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go
index a3c95ca..bda19e9 100644
--- a/ipc/uapi_linux.go
+++ b/ipc/uapi_linux.go
@@ -1,31 +1,18 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
- "errors"
- "fmt"
"net"
"os"
- "path"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
-var socketDirectory = "/var/run/wireguard"
-
-const (
- IpcErrorIO = -int64(unix.EIO)
- IpcErrorProtocol = -int64(unix.EPROTO)
- IpcErrorInvalid = -int64(unix.EINVAL)
- IpcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketName = "%s.sock"
-)
-
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
// watch for deletion of socket
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
+ socketPath := sockPath(name)
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
-
-func UAPIOpen(name string) (*os.File, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- addr, err := net.ResolveUnixAddr("unix", socketPath)
- if err != nil {
- return nil, err
- }
-
- oldUmask := unix.Umask(0077)
- listener, err := func() (*net.UnixListener, error) {
-
- // initial connection attempt
-
- listener, err := net.ListenUnix("unix", addr)
- if err == nil {
- return listener, nil
- }
-
- // check if socket already active
-
- _, err = net.Dial("unix", socketPath)
- if err == nil {
- return nil, errors.New("unix socket in use")
- }
-
- // cleanup & attempt again
-
- err = os.Remove(socketPath)
- if err != nil {
- return nil, err
- }
- return net.ListenUnix("unix", addr)
- }()
- unix.Umask(oldUmask)
-
- if err != nil {
- return nil, err
- }
-
- return listener.File()
-}
diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go
new file mode 100644
index 0000000..2e0813e
--- /dev/null
+++ b/ipc/uapi_unix.go
@@ -0,0 +1,63 @@
+// +build linux darwin freebsd openbsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package ipc
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ IpcErrorIO = -int64(unix.EIO)
+ IpcErrorProtocol = -int64(unix.EPROTO)
+ IpcErrorInvalid = -int64(unix.EINVAL)
+ IpcErrorPortInUse = -int64(unix.EADDRINUSE)
+)
+
+var socketDirectory = "/var/run/wireguard"
+
+func sockPath(iface string) string {
+ return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
+}
+
+func UAPIOpen(name string) (*os.File, error) {
+ if err := os.MkdirAll(socketDirectory, 0755); err != nil {
+ return nil, err
+ }
+
+ socketPath := sockPath(name)
+ addr, err := net.ResolveUnixAddr("unix", socketPath)
+ if err != nil {
+ return nil, err
+ }
+
+ oldUmask := unix.Umask(0077)
+ defer unix.Umask(oldUmask)
+
+ listener, err := net.ListenUnix("unix", addr)
+ if err == nil {
+ return listener.File()
+ }
+
+ // Test socket, if not in use cleanup and try again.
+ if _, err := net.Dial("unix", socketPath); err == nil {
+ return nil, errors.New("unix socket in use")
+ }
+ if err := os.Remove(socketPath); err != nil {
+ return nil, err
+ }
+ listener, err = net.ListenUnix("unix", addr)
+ if err != nil {
+ return nil, err
+ }
+ return listener.File()
+}
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index ead0dc5..7fa4f38 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go
index 29d02a7..09f2f1c 100644
--- a/ipc/winpipe/file.go
+++ b/ipc/winpipe/file.go
@@ -3,7 +3,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe
diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go
index 19ac03a..3675af7 100644
--- a/ipc/winpipe/mksyscall.go
+++ b/ipc/winpipe/mksyscall.go
@@ -1,7 +1,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe
diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go
index 06b3037..c587227 100644
--- a/ipc/winpipe/pipe.go
+++ b/ipc/winpipe/pipe.go
@@ -3,7 +3,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe
diff --git a/main.go b/main.go
index 053f488..75c922d 100644
--- a/main.go
+++ b/main.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -41,18 +41,16 @@ func warning() {
return
}
- fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
- fmt.Fprintln(os.Stderr, "W G")
- fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
- fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
- fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
- fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
- fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
- fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
- fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
- fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
- fmt.Fprintln(os.Stderr, "W G")
- fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
+ fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐")
+ fmt.Fprintln(os.Stderr, "│ │")
+ fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │")
+ fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │")
+ fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │")
+ fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │")
+ fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
+ fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │")
+ fmt.Fprintln(os.Stderr, "│ │")
+ fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
}
func main() {
diff --git a/main_windows.go b/main_windows.go
index f57bc8d..291b00d 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package main
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 772c45a..a1dea61 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
@@ -20,21 +20,23 @@ const (
)
type RatelimiterEntry struct {
- sync.Mutex
+ mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
- sync.RWMutex
- stopReset chan struct{}
+ mu sync.RWMutex
+ timeNow func() time.Time
+
+ stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- // stop any ongoing garbage collection routine
+ if rate.timeNow == nil {
+ rate.timeNow = time.Now
+ }
+ // stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
- // start garbage collection routine
+ stopReset := rate.stopReset // store in case Init is called again.
+ // Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
- case _, ok := <-rate.stopReset:
+ case _, ok := <-stopReset:
ticker.Stop()
- if ok {
- ticker = time.NewTicker(time.Second)
- } else {
+ if !ok {
return
}
+ ticker = time.NewTicker(time.Second)
case <-ticker.C:
- func() {
- rate.Lock()
- defer rate.Unlock()
-
- for key, entry := range rate.tableIPv4 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.Unlock()
- }
-
- for key, entry := range rate.tableIPv6 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.Unlock()
- }
-
- if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
- ticker.Stop()
- }
- }()
+ if rate.cleanup() {
+ ticker.Stop()
+ }
}
}
}()
}
+func (rate *Ratelimiter) cleanup() (empty bool) {
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ for key, entry := range rate.tableIPv4 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv4, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ for key, entry := range rate.tableIPv6 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv6, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+}
+
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4()
IPv6 := ip.To16()
- rate.RLock()
+ rate.mu.RLock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6]
}
- rate.RUnlock()
+ rate.mu.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
- entry.lastTime = time.Now()
- rate.Lock()
+ entry.lastTime = rate.timeNow()
+ rate.mu.Lock()
if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{}
}
}
- rate.Unlock()
+ rate.mu.Unlock()
return true
}
// add tokens to entry
- entry.Lock()
- now := time.Now()
+ entry.mu.Lock()
+ now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost {
entry.tokens -= packetCost
- entry.Unlock()
+ entry.mu.Unlock()
return true
}
- entry.Unlock()
+ entry.mu.Unlock()
return false
}
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index a18a097..d1e93fe 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
@@ -11,22 +11,21 @@ import (
"time"
)
-type RatelimiterResult struct {
+type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
+ var rate Ratelimiter
+ var expectedResults []result
- var ratelimiter Ratelimiter
- var expectedResults []RatelimiterResult
-
- Nano := func(nano int64) time.Duration {
+ nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
- Add := func(res RatelimiterResult) {
+ add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- text: "inital burst",
+ text: "initial burst",
})
}
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "after burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
+ wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "not having refilled enough",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+ wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "packet following 2 packet burst",
})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
- ratelimiter.Init()
+ now := time.Now()
+ rate.timeNow = func() time.Time {
+ return now
+ }
+ defer func() {
+ // Lock to avoid data race with cleanup goroutine from Init.
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ rate.timeNow = time.Now
+ }()
+ timeSleep := func(d time.Duration) {
+ now = now.Add(d + 1)
+ rate.cleanup()
+ }
+
+ rate.Init()
+ defer rate.Close()
for i, res := range expectedResults {
- time.Sleep(res.wait)
+ timeSleep(res.wait)
for _, ip := range ips {
- allowed := ratelimiter.Allow(ip)
+ allowed := rate.Allow(ip)
if allowed != res.allowed {
- t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+ t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}
diff --git a/replay/replay.go b/replay/replay.go
index 0f6b6c9..e5c7391 100644
--- a/replay/replay.go
+++ b/replay/replay.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package replay
@@ -21,7 +21,7 @@ const (
const (
CounterRedundantBitsLog = _WordLogSize + 3
CounterRedundantBits = _WordSize * 8
- CounterBitsTotal = 2048
+ CounterBitsTotal = 8192
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
)
diff --git a/replay/replay_test.go b/replay/replay_test.go
index 5365f10..ceae2f3 100644
--- a/replay/replay_test.go
+++ b/replay/replay_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package replay
diff --git a/rwcancel/fdset.go b/rwcancel/fdset.go
index 28746e6..36d0fec 100644
--- a/rwcancel/fdset.go
+++ b/rwcancel/fdset.go
@@ -1,6 +1,8 @@
+// +build !windows
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go
index 808e691..f91a1bf 100644
--- a/rwcancel/rwcancel.go
+++ b/rwcancel/rwcancel.go
@@ -1,8 +1,12 @@
+// +build !windows
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
+// Package rwcancel implements cancelable read/write operations on
+// a file descriptor.
package rwcancel
import (
diff --git a/rwcancel/rwcancel_windows.go b/rwcancel/rwcancel_windows.go
new file mode 100644
index 0000000..0316911
--- /dev/null
+++ b/rwcancel/rwcancel_windows.go
@@ -0,0 +1,8 @@
+// SPDX-License-Identifier: MIT
+
+package rwcancel
+
+type RWCancel struct {
+}
+
+func (*RWCancel) Cancel() {}
diff --git a/rwcancel/select_default.go b/rwcancel/select_default.go
index dd23cda..8cb45d8 100644
--- a/rwcancel/select_default.go
+++ b/rwcancel/select_default.go
@@ -1,8 +1,8 @@
-// +build !linux
+// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
@@ -10,5 +10,6 @@ package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
- return unix.Select(nfd, r, w, e, timeout)
+ _, err := unix.Select(nfd, r, w, e, timeout)
+ return err
}
diff --git a/rwcancel/select_linux.go b/rwcancel/select_linux.go
index 1a72e0a..204d04a 100644
--- a/rwcancel/select_linux.go
+++ b/rwcancel/select_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go
index 565aaa4..2838f4f 100644
--- a/tai64n/tai64n.go
+++ b/tai64n/tai64n.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -17,16 +17,19 @@ const whitenerMask = uint32(0x1000000 - 1)
type Timestamp [TimestampSize]byte
-func Now() Timestamp {
+func stamp(t time.Time) Timestamp {
var tai64n Timestamp
- now := time.Now()
- secs := base + uint64(now.Unix())
- nano := uint32(now.Nanosecond()) &^ whitenerMask
+ secs := base + uint64(t.Unix())
+ nano := uint32(t.Nanosecond()) &^ whitenerMask
binary.BigEndian.PutUint64(tai64n[:], secs)
binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n
}
+func Now() Timestamp {
+ return stamp(time.Now())
+}
+
func (t1 Timestamp) After(t2 Timestamp) bool {
return bytes.Compare(t1[:], t2[:]) > 0
}
diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go
index 859660f..6df7367 100644
--- a/tai64n/tai64n_test.go
+++ b/tai64n/tai64n_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -10,21 +10,31 @@ import (
"time"
)
-/* Testing the essential property of the timestamp
- * as used by WireGuard.
- */
+// Test that timestamps are monotonic as required by Wireguard and that
+// nanosecond-level information is whitened to prevent side channel attacks.
func TestMonotonic(t *testing.T) {
- old := Now()
- for i := 0; i < 50; i++ {
- next := Now()
- if next.After(old) {
- t.Error("Whitening insufficient")
- }
- time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
- next = Now()
- if !next.After(old) {
- t.Error("Not monotonically increasing on whitened nano-second scale")
- }
- old = next
+ startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
+ // Whitening should reduce timestamp granularity
+ // to more than 10 but fewer than 20 milliseconds.
+ tests := []struct {
+ name string
+ t1, t2 time.Time
+ wantAfter bool
+ }{
+ {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
+ {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
+ {"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
+ {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
+ {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
+ got := ts2.After(ts1)
+ if got != tt.wantAfter {
+ t.Errorf("after = %v; want %v", got, tt.wantAfter)
+ }
+ })
}
}
diff --git a/tun/operateonfd.go b/tun/operateonfd.go
index 31747a2..ed7e633 100644
--- a/tun/operateonfd.go
+++ b/tun/operateonfd.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun.go b/tun/tun.go
index 5395bdb..4f6848f 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 6d2e6dd..52b4070 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -11,6 +11,7 @@ import (
"net"
"os"
"syscall"
+ "time"
"unsafe"
"golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32
+func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
+ for i := 0; i < 20; i++ {
+ iface, err = net.InterfaceByIndex(index)
+ if err != nil {
+ if opErr, ok := err.(*net.OpError); ok {
+ if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
+ time.Sleep(time.Duration(i) * time.Second / 3)
+ continue
+ }
+ }
+ }
+ return iface, err
+ }
+ return nil, err
+}
+
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue
}
- iface, err := net.InterfaceByIndex(ifindex)
+ iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index 6cf9313..c312219 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 61902e9..791e0be 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -12,7 +12,6 @@ import (
"bytes"
"errors"
"fmt"
- "net"
"os"
"sync"
"syscall"
@@ -32,14 +31,17 @@ const (
type NativeTun struct {
tunFile *os.File
index int32 // if index
- name string // name of interface
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was pased IFF_NO_PI
+ nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+
+ nameOnce sync.Once // guards calling initNameCache, which sets following fields
+ nameCache string // name of interface
+ nameErr error
}
func (tun *NativeTun) File() *os.File {
@@ -64,14 +66,19 @@ func (tun *NativeTun) routineHackListener() {
}
switch err {
case unix.EINVAL:
+ // If the tunnel is up, it reports that write() is
+ // allowed but we provided invalid data.
tun.events <- EventUp
case unix.EIO:
+ // If the tunnel is down, it reports that no I/O
+ // is possible, without checking our provided data.
tun.events <- EventDown
default:
return
}
select {
case <-time.After(time.Second):
+ // nothing
case <-tun.statusListenersShutdown:
return
}
@@ -85,7 +92,7 @@ func createNetlinkSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
+ Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -126,6 +133,7 @@ func (tun *NativeTun) routineNetlinkListener() {
default:
}
+ wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
@@ -149,10 +157,16 @@ func (tun *NativeTun) routineNetlinkListener() {
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
+ wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
- tun.events <- EventDown
+ // Don't emit EventDown before we've ever emitted EventUp.
+ // This avoids a startup race with HackListener, which
+ // might detect Up before we have finished reporting Down.
+ if wasEverUp {
+ tun.events <- EventDown
+ }
}
tun.events <- EventMTUUpdate
@@ -164,11 +178,6 @@ func (tun *NativeTun) routineNetlinkListener() {
}
}
-func (tun *NativeTun) isUp() (bool, error) {
- inter, err := net.InterfaceByName(tun.name)
- return inter.Flags&net.FlagUp != 0, err
-}
-
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
@@ -198,6 +207,11 @@ func getIFIndex(name string) (int32, error) {
}
func (tun *NativeTun) setMTU(n int) error {
+ name, err := tun.Name()
+ if err != nil {
+ return err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
@@ -212,9 +226,8 @@ func (tun *NativeTun) setMTU(n int) error {
defer unix.Close(fd)
// do ioctl call
-
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
@@ -231,6 +244,11 @@ func (tun *NativeTun) setMTU(n int) error {
}
func (tun *NativeTun) MTU() (int, error) {
+ name, err := tun.Name()
+ if err != nil {
+ return 0, err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
@@ -247,7 +265,7 @@ func (tun *NativeTun) MTU() (int, error) {
// do ioctl call
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
@@ -262,6 +280,15 @@ func (tun *NativeTun) MTU() (int, error) {
}
func (tun *NativeTun) Name() (string, error) {
+ tun.nameOnce.Do(tun.initNameCache)
+ return tun.nameCache, tun.nameErr
+}
+
+func (tun *NativeTun) initNameCache() {
+ tun.nameCache, tun.nameErr = tun.nameSlow()
+}
+
+func (tun *NativeTun) nameSlow() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
@@ -282,13 +309,11 @@ func (tun *NativeTun) Name() (string, error) {
if errno != 0 {
return "", errors.New("failed to get name of TUN device: " + errno.Error())
}
- nullStr := ifr[:]
- i := bytes.IndexByte(nullStr, 0)
- if i != -1 {
- nullStr = nullStr[:i]
+ name := ifr[:]
+ if i := bytes.IndexByte(name, 0); i != -1 {
+ name = name[:i]
}
- tun.name = string(nullStr)
- return tun.name, nil
+ return string(name), nil
}
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
@@ -367,6 +392,9 @@ func (tun *NativeTun) Close() error {
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
if err != nil {
+ if os.IsNotExist(err) {
+ return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
+ }
return nil, err
}
@@ -408,16 +436,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
statusListenersShutdown: make(chan struct{}),
nopi: false,
}
- var err error
- _, err = tun.Name()
+ name, err := tun.Name()
if err != nil {
return nil, err
}
// start event listener
- tun.index, err = getIFIndex(tun.name)
+ tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
}
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index 44cedaa..2003420 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index daad4aa..5a52c56 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2018-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"os"
+ "sync"
"sync/atomic"
"time"
"unsafe"
@@ -35,11 +36,12 @@ type NativeTun struct {
wt *wintun.Interface
handle windows.Handle
close bool
- rings wintun.RingDescriptor
events chan Event
errors chan error
forcedMTU int
rate rateJuggler
+ rings *wintun.RingDescriptor
+ writeLock sync.Mutex
}
const WintunPool = wintun.Pool("WireGuard")
@@ -93,13 +95,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
forcedMTU: forcedMTU,
}
- err = tun.rings.Init()
+ tun.rings, err = wintun.NewRingDescriptor()
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err)
}
- tun.handle, err = tun.wt.Register(&tun.rings)
+ tun.handle, err = tun.wt.Register(tun.rings)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -219,6 +221,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
+ tun.writeLock.Lock()
+ defer tun.writeLock.Unlock()
+
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
new file mode 100644
index 0000000..e94c6d8
--- /dev/null
+++ b/tun/tuntest/tuntest.go
@@ -0,0 +1,150 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package tuntest
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "os"
+
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+func Ping(dst, src net.IP) []byte {
+ localPort := uint16(1337)
+ seq := uint16(0)
+
+ payload := make([]byte, 4)
+ binary.BigEndian.PutUint16(payload[0:], localPort)
+ binary.BigEndian.PutUint16(payload[2:], seq)
+
+ return genICMPv4(payload, dst, src)
+}
+
+// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
+func checksum(buf []byte, initial uint16) uint16 {
+ v := uint32(initial)
+ for i := 0; i < len(buf)-1; i += 2 {
+ v += uint32(binary.BigEndian.Uint16(buf[i:]))
+ }
+ if len(buf)%2 == 1 {
+ v += uint32(buf[len(buf)-1]) << 8
+ }
+ for v > 0xffff {
+ v = (v >> 16) + (v & 0xffff)
+ }
+ return ^uint16(v)
+}
+
+func genICMPv4(payload []byte, dst, src net.IP) []byte {
+ const (
+ icmpv4ProtocolNumber = 1
+ icmpv4Echo = 8
+ icmpv4ChecksumOffset = 2
+ icmpv4Size = 8
+ ipv4Size = 20
+ ipv4TotalLenOffset = 2
+ ipv4ChecksumOffset = 10
+ ttl = 65
+ )
+
+ hdr := make([]byte, ipv4Size+icmpv4Size)
+
+ ip := hdr[0:ipv4Size]
+ icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
+
+ // https://tools.ietf.org/html/rfc792
+ icmpv4[0] = icmpv4Echo // type
+ icmpv4[1] = 0 // code
+ chksum := ^checksum(icmpv4, checksum(payload, 0))
+ binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+
+ // https://tools.ietf.org/html/rfc760 section 3.1
+ length := uint16(len(hdr) + len(payload))
+ ip[0] = (4 << 4) | (ipv4Size / 4)
+ binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
+ ip[8] = ttl
+ ip[9] = icmpv4ProtocolNumber
+ copy(ip[12:], src.To4())
+ copy(ip[16:], dst.To4())
+ chksum = ^checksum(ip[:], 0)
+ binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+
+ var v []byte
+ v = append(v, hdr...)
+ v = append(v, payload...)
+ return []byte(v)
+}
+
+// TODO(crawshaw): find a reusable home for this. package devicetest?
+type ChannelTUN struct {
+ Inbound chan []byte // incoming packets, closed on TUN close
+ Outbound chan []byte // outbound packets, blocks forever on TUN close
+
+ closed chan struct{}
+ events chan tun.Event
+ tun chTun
+}
+
+func NewChannelTUN() *ChannelTUN {
+ c := &ChannelTUN{
+ Inbound: make(chan []byte),
+ Outbound: make(chan []byte),
+ closed: make(chan struct{}),
+ events: make(chan tun.Event, 1),
+ }
+ c.tun.c = c
+ c.events <- tun.EventUp
+ return c
+}
+
+func (c *ChannelTUN) TUN() tun.Device {
+ return &c.tun
+}
+
+type chTun struct {
+ c *ChannelTUN
+}
+
+func (t *chTun) File() *os.File { return nil }
+
+func (t *chTun) Read(data []byte, offset int) (int, error) {
+ select {
+ case <-t.c.closed:
+ return 0, io.EOF // TODO(crawshaw): what is the correct error value?
+ case msg := <-t.c.Outbound:
+ return copy(data[offset:], msg), nil
+ }
+}
+
+// Write is called by the wireguard device to deliver a packet for routing.
+func (t *chTun) Write(data []byte, offset int) (int, error) {
+ if offset == -1 {
+ close(t.c.closed)
+ close(t.c.events)
+ return 0, io.EOF
+ }
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return 0, io.EOF // TODO(crawshaw): what is the correct error value?
+ case t.c.Inbound <- msg:
+ return len(data) - offset, nil
+ }
+}
+
+const DefaultMTU = 1420
+
+func (t *chTun) Flush() error { return nil }
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() chan tun.Event { return t.c.events }
+func (t *chTun) Close() error {
+ t.Write(nil, -1)
+ return nil
+}
diff --git a/tun/wintun/iphlpapi/conversion_windows.go b/tun/wintun/iphlpapi/conversion_windows.go
index a19e961..d2db8a3 100644
--- a/tun/wintun/iphlpapi/conversion_windows.go
+++ b/tun/wintun/iphlpapi/conversion_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
diff --git a/tun/wintun/iphlpapi/mksyscall.go b/tun/wintun/iphlpapi/mksyscall.go
index fc7dba4..8ffc0d4 100644
--- a/tun/wintun/iphlpapi/mksyscall.go
+++ b/tun/wintun/iphlpapi/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go
index f4316fe..302ad45 100644
--- a/tun/wintun/namespace_windows.go
+++ b/tun/wintun/namespace_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package wintun
@@ -59,9 +59,12 @@ func initializeNamespace() error {
if err == windows.ERROR_PATH_NOT_FOUND {
continue
}
+ if err != nil {
+ return fmt.Errorf("OpenPrivateNamespace failed: %v", err)
+ }
}
if err != nil {
- return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
+ return fmt.Errorf("CreatePrivateNamespace failed: %v", err)
}
break
}
diff --git a/tun/wintun/namespaceapi/mksyscall.go b/tun/wintun/namespaceapi/mksyscall.go
index 93d43b0..8ea3085 100644
--- a/tun/wintun/namespaceapi/mksyscall.go
+++ b/tun/wintun/namespaceapi/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
diff --git a/tun/wintun/namespaceapi/namespaceapi_windows.go b/tun/wintun/namespaceapi/namespaceapi_windows.go
index a3a6274..e71077c 100644
--- a/tun/wintun/namespaceapi/namespaceapi_windows.go
+++ b/tun/wintun/namespaceapi/namespaceapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
diff --git a/tun/wintun/nci/mksyscall.go b/tun/wintun/nci/mksyscall.go
index 019da93..129e015 100644
--- a/tun/wintun/nci/mksyscall.go
+++ b/tun/wintun/nci/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package nci
diff --git a/tun/wintun/nci/nci_windows.go b/tun/wintun/nci/nci_windows.go
index 9dc6699..dc9733c 100644
--- a/tun/wintun/nci/nci_windows.go
+++ b/tun/wintun/nci/nci_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package nci
diff --git a/tun/wintun/registry/mksyscall.go b/tun/wintun/registry/mksyscall.go
index 6ad82d2..3e9ff1f 100644
--- a/tun/wintun/registry/mksyscall.go
+++ b/tun/wintun/registry/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package registry
diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go
index 12a0336..6be88fd 100644
--- a/tun/wintun/registry/registry_windows.go
+++ b/tun/wintun/registry/registry_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package registry
diff --git a/tun/wintun/registry/registry_windows_test.go b/tun/wintun/registry/registry_windows_test.go
index c56b51b..2479b3d 100644
--- a/tun/wintun/registry/registry_windows_test.go
+++ b/tun/wintun/registry/registry_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package registry
diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go
index 8f46bc9..4d2fab6 100644
--- a/tun/wintun/ring_windows.go
+++ b/tun/wintun/ring_windows.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
+ "runtime"
"unsafe"
"golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
}
-func (descriptor *RingDescriptor) Init() (err error) {
+func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
+ descriptor = new(RingDescriptor)
+ allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+ if err != nil {
+ return
+ }
+ defer func() {
+ if err != nil {
+ descriptor.free()
+ descriptor = nil
+ }
+ }()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Send.Ring = &Ring{}
+ descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return
}
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Receive.Ring = &Ring{}
+ descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved)
return
}
-
+ runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return
}
+func (descriptor *RingDescriptor) free() {
+ if descriptor.Send.Ring != nil {
+ windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
+ descriptor.Send.Ring = nil
+ descriptor.Receive.Ring = nil
+ }
+}
+
func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved)
diff --git a/tun/wintun/setupapi/mksyscall.go b/tun/wintun/setupapi/mksyscall.go
index ac103a1..234851c 100644
--- a/tun/wintun/setupapi/mksyscall.go
+++ b/tun/wintun/setupapi/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go
index 60a8eb7..a804dd8 100644
--- a/tun/wintun/setupapi/setupapi_windows.go
+++ b/tun/wintun/setupapi/setupapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go
index a9e6b89..b0afbc7 100644
--- a/tun/wintun/setupapi/setupapi_windows_test.go
+++ b/tun/wintun/setupapi/setupapi_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/setupapi/types_windows_386.go b/tun/wintun/setupapi/types32_windows.go
index 132f921..0eaead6 100644
--- a/tun/wintun/setupapi/types_windows_386.go
+++ b/tun/wintun/setupapi/types32_windows.go
@@ -1,6 +1,8 @@
+// +build 386 arm
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/setupapi/types_windows_amd64.go b/tun/wintun/setupapi/types64_windows.go
index d4dd65c..c815b8f 100644
--- a/tun/wintun/setupapi/types_windows_amd64.go
+++ b/tun/wintun/setupapi/types64_windows.go
@@ -1,6 +1,8 @@
+// +build amd64 arm64
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/setupapi/types_windows.go b/tun/wintun/setupapi/types_windows.go
index 136b4be..43e3f39 100644
--- a/tun/wintun/setupapi/types_windows.go
+++ b/tun/wintun/setupapi/types_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr
}
-// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
+// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID
diff --git a/tun/wintun/setupapi/zsetupapi_windows_test.go b/tun/wintun/setupapi/zsetupapi_windows_test.go
index 915b427..5b5f369 100644
--- a/tun/wintun/setupapi/zsetupapi_windows_test.go
+++ b/tun/wintun/setupapi/zsetupapi_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package setupapi
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
index 4c12d97..eed4f01 100644
--- a/tun/wintun/wintun_windows.go
+++ b/tun/wintun/wintun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package wintun
@@ -136,7 +136,7 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
if err != nil {
continue
}
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
+ if !isOurHardwareID(property) {
continue
}
@@ -508,7 +508,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
if err != nil {
continue
}
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
+ if !isOurHardwareID(property) {
continue
}
@@ -801,3 +801,20 @@ func (wintun *Interface) GUID() windows.GUID {
func (wintun *Interface) LUID() uint64 {
return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
}
+
+func isOurHardwareID(property interface{}) bool {
+ hwidLC := strings.ToLower(hardwareID)
+
+ if hwids, ok := property.([]string); ok && len(hwids) > 0 {
+ for i := range hwids {
+ if strings.ToLower(hwids[i]) == hwidLC {
+ return true
+ }
+ }
+ }
+ if hwid, ok := property.(string); ok && strings.ToLower(hwid) == hwidLC {
+ return true
+ }
+
+ return false
+}