aboutsummaryrefslogtreecommitdiffstats
path: root/conn_linux.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-04-20 04:05:11 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-04-20 06:51:28 +0200
commit5ba84696e29c6109e84b1f48247ae02a2bcb106e (patch)
tree66084df3aae2c36bf055bc66f155e2c65e0a6de9 /conn_linux.go
parentCheck for correct first nibble (diff)
downloadwireguard-go-5ba84696e29c6109e84b1f48247ae02a2bcb106e.tar.xz
wireguard-go-5ba84696e29c6109e84b1f48247ae02a2bcb106e.zip
Rework sticky sockets
Diffstat (limited to 'conn_linux.go')
-rw-r--r--conn_linux.go336
1 files changed, 150 insertions, 186 deletions
diff --git a/conn_linux.go b/conn_linux.go
index 8b60d65..88b9ef4 100644
--- a/conn_linux.go
+++ b/conn_linux.go
@@ -1,13 +1,18 @@
-/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+/* Copyright 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
- * WireGuard's kernelspace implementation.
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
*/
package main
import (
- "encoding/binary"
"errors"
"golang.org/x/sys/unix"
"net"
@@ -15,41 +20,46 @@ import (
"unsafe"
)
-/* Supports source address caching
- *
- * Currently there is no way to achieve this within the net package:
- * See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
- */
+type IPv4Source struct {
+ src [4]byte
+ ifindex int32
+}
+
+type IPv6Source struct {
+ src [16]byte
+ //ifindex belongs in dst.ZoneId
+}
+
type NativeEndpoint struct {
- src unix.RawSockaddrInet6
- dst unix.RawSockaddrInet6
+ dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
+ src [unsafe.Sizeof(IPv6Source{})]byte
+ isV6 bool
}
-type NativeBind struct {
- sock4 int
- sock6 int
+func (endpoint *NativeEndpoint) src4() *IPv4Source {
+ return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
-var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = NativeBind{}
+func (endpoint *NativeEndpoint) src6() *IPv6Source {
+ return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
+}
-type IPv4Source struct {
- src unix.RawSockaddrInet4
- Ifindex int32
+func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
+ return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
-func htons(val uint16) uint16 {
- var out [unsafe.Sizeof(val)]byte
- binary.BigEndian.PutUint16(out[:], val)
- return *((*uint16)(unsafe.Pointer(&out[0])))
+func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
+ return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
-func ntohs(val uint16) uint16 {
- tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
- return binary.BigEndian.Uint16((*tmp)[:])
+type NativeBind struct {
+ sock4 int
+ sock6 int
}
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = NativeBind{}
+
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
@@ -59,10 +69,9 @@ func CreateEndpoint(s string) (Endpoint, error) {
ipv4 := addr.IP.To4()
if ipv4 != nil {
- dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
- dst.Family = unix.AF_INET
- dst.Port = htons(uint16(addr.Port))
- dst.Zero = [8]byte{}
+ dst := end.dst4()
+ end.isV6 = false
+ dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
@@ -74,17 +83,16 @@ func CreateEndpoint(s string) (Endpoint, error) {
if err != nil {
return nil, err
}
- dst := &end.dst
- dst.Family = unix.AF_INET6
- dst.Port = htons(uint16(addr.Port))
- dst.Flowinfo = 0
- dst.Scope_id = zone
+ dst := end.dst6()
+ end.isV6 = true
+ dst.Port = addr.Port
+ dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
- return nil, errors.New("Failed to recognize IP address format")
+ return nil, errors.New("Invalid IP address")
}
func CreateBind(port uint16) (Bind, uint16, error) {
@@ -160,86 +168,85 @@ func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
- switch nend.dst.Family {
- case unix.AF_INET6:
- return send6(bind.sock6, nend, buff)
- case unix.AF_INET:
+ if !nend.isV6 {
return send4(bind.sock4, nend, buff)
- default:
- return errors.New("Unknown address family of destination")
+ } else {
+ return send6(bind.sock6, nend, buff)
}
}
-func sockaddrToString(addr unix.RawSockaddrInet6) string {
- var udpAddr net.UDPAddr
-
- switch addr.Family {
- case unix.AF_INET6:
- udpAddr.Port = int(ntohs(addr.Port))
- udpAddr.IP = addr.Addr[:]
- return udpAddr.String()
-
- case unix.AF_INET:
- ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
- udpAddr.Port = int(ntohs(ptr.Port))
- udpAddr.IP = net.IPv4(
- ptr.Addr[0],
- ptr.Addr[1],
- ptr.Addr[2],
- ptr.Addr[3],
- )
- return udpAddr.String()
+func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP {
+ return net.IPv4(
+ addr.Addr[0],
+ addr.Addr[1],
+ addr.Addr[2],
+ addr.Addr[3],
+ )
+}
- default:
- return "<unknown address family>"
- }
+func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP {
+ return addr.Addr[:]
}
-func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
- switch addr.Family {
- case unix.AF_INET6:
- return addr.Addr[:]
- case unix.AF_INET:
- ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+func (end *NativeEndpoint) SrcIP() net.IP {
+ if !end.isV6 {
return net.IPv4(
- ptr.Addr[0],
- ptr.Addr[1],
- ptr.Addr[2],
- ptr.Addr[3],
+ end.src4().src[0],
+ end.src4().src[1],
+ end.src4().src[2],
+ end.src4().src[3],
)
- default:
- return nil
+ } else {
+ return end.src6().src[:]
}
}
-func (end *NativeEndpoint) SrcIP() net.IP {
- return rawAddrToIP(end.src)
-}
-
func (end *NativeEndpoint) DstIP() net.IP {
- return rawAddrToIP(end.dst)
+ if !end.isV6 {
+ return net.IPv4(
+ end.dst4().Addr[0],
+ end.dst4().Addr[1],
+ end.dst4().Addr[2],
+ end.dst4().Addr[3],
+ )
+ } else {
+ return end.dst6().Addr[:]
+ }
}
func (end *NativeEndpoint) DstToBytes() []byte {
- ptr := unsafe.Pointer(&end.src)
- arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
- return arr[:]
+ if !end.isV6 {
+ return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
+ } else {
+ return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
+ }
}
func (end *NativeEndpoint) SrcToString() string {
- return sockaddrToString(end.src)
+ return end.SrcIP().String()
}
func (end *NativeEndpoint) DstToString() string {
- return sockaddrToString(end.dst)
+ var udpAddr net.UDPAddr
+ udpAddr.IP = end.DstIP()
+ if !end.isV6 {
+ udpAddr.Port = end.dst4().Port
+ } else {
+ udpAddr.Port = end.dst6().Port
+ }
+ return udpAddr.String()
}
func (end *NativeEndpoint) ClearDst() {
- end.dst = unix.RawSockaddrInet6{}
+ for i := range end.dst {
+ end.dst[i] = 0
+ }
}
func (end *NativeEndpoint) ClearSrc() {
- end.src = unix.RawSockaddrInet6{}
+ for i := range end.src {
+ end.src[i] = 0
+ }
}
func zoneToUint32(zone string) (uint32, error) {
@@ -295,6 +302,7 @@ func create4(port uint16) (int, uint16, error) {
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
+ return -1, 0, err
}
return fd, uint16(addr.Port), err
@@ -353,140 +361,106 @@ func create6(port uint16) (int, uint16, error) {
}(); err != nil {
unix.Close(fd)
+ return -1, 0, err
}
return fd, uint16(addr.Port), err
}
-func send6(sock int, end *NativeEndpoint, buff []byte) error {
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
- var iovec unix.Iovec
- iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
- iovec.SetLen(len(buff))
-
cmsg := struct {
cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
+ pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
- unix.Inet6Pktinfo{
- Addr: end.src.Addr,
- Ifindex: end.src.Scope_id,
+ unix.Inet4Pktinfo{
+ Spec_dst: end.src4().src,
+ Ifindex: end.src4().ifindex,
},
}
- msghdr := unix.Msghdr{
- Iov: &iovec,
- Iovlen: 1,
- Name: (*byte)(unsafe.Pointer(&end.dst)),
- Namelen: unix.SizeofSockaddrInet6,
- Control: (*byte)(unsafe.Pointer(&cmsg)),
- }
-
- msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
-
- _, _, errno := sendmsg(sock, &msghdr, 0)
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- if errno == 0 {
+ if err == nil {
return nil
}
// clear src and retry
- if errno == unix.EINVAL {
+ if err == unix.EINVAL {
end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- _, _, errno = sendmsg(sock, &msghdr, 0)
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
- return errno
+ return err
}
-func send4(sock int, end *NativeEndpoint, buff []byte) error {
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
- var iovec unix.Iovec
- iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
- iovec.SetLen(len(buff))
-
- src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
-
cmsg := struct {
cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
+ pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
- unix.Inet4Pktinfo{
- Spec_dst: src4.src.Addr,
- Ifindex: src4.Ifindex,
+ unix.Inet6Pktinfo{
+ Addr: end.src6().src,
+ Ifindex: end.dst6().ZoneId,
},
}
- msghdr := unix.Msghdr{
- Iov: &iovec,
- Iovlen: 1,
- Name: (*byte)(unsafe.Pointer(&end.dst)),
- Namelen: unix.SizeofSockaddrInet4,
- Control: (*byte)(unsafe.Pointer(&cmsg)),
- Flags: 0,
+ if cmsg.pktinfo.Addr == [16]byte{} {
+ cmsg.pktinfo.Ifindex = 0
}
- msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
-
- _, _, errno := sendmsg(sock, &msghdr, 0)
- // clear source and try again
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- if errno == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- _, _, errno = sendmsg(sock, &msghdr, 0)
+ if err == nil {
+ return nil
}
- // errno = 0 is still an error instance
+ // clear src and retry
- if errno == 0 {
- return nil
+ if err == unix.EINVAL {
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet6Pktinfo{}
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
}
- return errno
+ return err
}
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
- var iovec unix.Iovec
- iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
- iovec.SetLen(len(buff))
-
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
- var msghdr unix.Msghdr
- msghdr.Iov = &iovec
- msghdr.Iovlen = 1
- msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
- msghdr.Namelen = unix.SizeofSockaddrInet4
- msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
- msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- size, _, errno := recvmsg(sock, &msghdr, 0)
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = false
- if errno != 0 {
- return 0, errno
+ if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
+ *end.dst4() = *newDst4
}
// update source cache
@@ -494,40 +468,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
- src4.src.Family = unix.AF_INET
- src4.src.Addr = cmsg.pktinfo.Spec_dst
- src4.Ifindex = cmsg.pktinfo.Ifindex
+ end.src4().src = cmsg.pktinfo.Spec_dst
+ end.src4().ifindex = cmsg.pktinfo.Ifindex
}
- return int(size), nil
+ return size, nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
- var iovec unix.Iovec
- iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
- iovec.SetLen(len(buff))
-
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
- var msg unix.Msghdr
- msg.Iov = &iovec
- msg.Iovlen = 1
- msg.Name = (*byte)(unsafe.Pointer(&end.dst))
- msg.Namelen = uint32(unix.SizeofSockaddrInet6)
- msg.Control = (*byte)(unsafe.Pointer(&cmsg))
- msg.SetControllen(int(unsafe.Sizeof(cmsg)))
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- size, _, errno := recvmsg(sock, &msg, 0)
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = true
- if errno != 0 {
- return 0, errno
+ if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
+ *end.dst6() = *newDst6
}
// update source cache
@@ -535,10 +500,9 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src.Family = unix.AF_INET6
- end.src.Addr = cmsg.pktinfo.Addr
- end.src.Scope_id = cmsg.pktinfo.Ifindex
+ end.src6().src = cmsg.pktinfo.Addr
+ end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
- return int(size), nil
+ return size, nil
}