From 203554620dc8114de1ff70bb30b80f828e9e26ad Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Thu, 7 Nov 2019 11:13:05 -0500 Subject: conn: introduce new package that splits out the Bind and Endpoint types The sticky socket code stays in the device package for now, as it reaches deeply into the peer list. This is the first step in an effort to split some code out of the very busy device package. Signed-off-by: David Crawshaw --- device/bind_test.go | 14 +- device/bindsocketshim.go | 36 +++ device/boundif_windows.go | 64 ---- device/conn.go | 187 ----------- device/conn_default.go | 178 ----------- device/conn_linux.go | 766 ---------------------------------------------- device/device.go | 146 ++++++++- device/mark_default.go | 12 - device/mark_unix.go | 65 ---- device/peer.go | 6 +- device/receive.go | 9 +- device/sticky_default.go | 12 + device/sticky_linux.go | 215 +++++++++++++ device/uapi.go | 3 +- 14 files changed, 419 insertions(+), 1294 deletions(-) create mode 100644 device/bindsocketshim.go delete mode 100644 device/boundif_windows.go delete mode 100644 device/conn.go delete mode 100644 device/conn_default.go delete mode 100644 device/conn_linux.go delete mode 100644 device/mark_default.go delete mode 100644 device/mark_unix.go create mode 100644 device/sticky_default.go create mode 100644 device/sticky_linux.go (limited to 'device') diff --git a/device/bind_test.go b/device/bind_test.go index 0c2e2cf..c5f7f68 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -5,11 +5,15 @@ package device -import "errors" +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) type DummyDatagram struct { msg []byte - endpoint Endpoint + endpoint conn.Endpoint world bool // better type } @@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") @@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") @@ -50,6 +54,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end Endpoint) error { +func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go new file mode 100644 index 0000000..c4dd4ef --- /dev/null +++ b/device/bindsocketshim.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface4(interfaceIndex, blackhole) + } + return nil +} + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface6(interfaceIndex, blackhole) + } + return nil +} diff --git a/device/boundif_windows.go b/device/boundif_windows.go deleted file mode 100644 index 6908415..0000000 --- a/device/boundif_windows.go +++ /dev/null @@ -1,64 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "encoding/binary" - "errors" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - sockoptIP_UNICAST_IF = 31 - sockoptIPV6_UNICAST_IF = 31 -) - -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, interfaceIndex) - interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - device.net.bind.(*nativeBind).blackhole4 = blackhole - return nil -} - -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - device.net.bind.(*nativeBind).blackhole6 = blackhole - return nil -} diff --git a/device/conn.go b/device/conn.go deleted file mode 100644 index 7b341f6..0000000 --- a/device/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "net" - "strings" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/device/conn_default.go b/device/conn_default.go deleted file mode 100644 index 661f57d..0000000 --- a/device/conn_default.go +++ /dev/null @@ -1,178 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "net" - "os" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // retrieve port - - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func extractErrno(err error) error { - opErr, ok := err.(*net.OpError) - if !ok { - return nil - } - syscallErr, ok := opErr.Err.(*os.SyscallError) - if !ok { - return nil - } - return syscallErr.Err -} - -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { - var err error - var bind nativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole4 { - return nil - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole6 { - return nil - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/device/conn_linux.go b/device/conn_linux.go deleted file mode 100644 index e90b0e3..0000000 --- a/device/conn_linux.go +++ /dev/null @@ -1,766 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ - -package device - -import ( - "errors" - "net" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" -) - -const ( - FD_ERR = -1 -) - -type IPv4Source struct { - src [4]byte - ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type nativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: unix.RTMGRP_IPV4_ROUTE, - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { - var err error - var bind nativeBind - var newPort uint16 - - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if successful - - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - return nil, 0, err - } - } else { - port = newPort - } - - // attempt ipv4 bind, update port if successful - - bind.sock4, newPort, err = create4(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *nativeBind) SetMark(value uint32) error { - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - -func (bind *nativeBind) Close() error { - var err1, err2, err3 error - if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) - } - if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) - } - err3 = bind.netlinkCancel.Cancel() - - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return syscall.EAFNOSUPPORT - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return syscall.EAFNOSUPPORT - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, - }, - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func (bind *nativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_MARK, - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/device/device.go b/device/device.go index 8c08f1c..a9fedea 100644 --- a/device/device.go +++ b/device/device.go @@ -11,15 +11,14 @@ import ( "sync/atomic" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -39,9 +38,10 @@ type Device struct { starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } staticIdentity struct { @@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { + device.state.starting.Add(3) + device.state.stopping.Add(3) go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() } + device.state.starting.Add(2) + device.state.stopping.Add(2) go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() @@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(2) + device.net.stopping.Add(2) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} diff --git a/device/mark_default.go b/device/mark_default.go deleted file mode 100644 index 7de2524..0000000 --- a/device/mark_default.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build !linux,!openbsd,!freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -func (bind *nativeBind) SetMark(mark uint32) error { - return nil -} diff --git a/device/mark_unix.go b/device/mark_unix.go deleted file mode 100644 index 669b328..0000000 --- a/device/mark_unix.go +++ /dev/null @@ -1,65 +0,0 @@ -// +build android openbsd freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "runtime" - - "golang.org/x/sys/unix" -) - -var fwmarkIoctl int - -func init() { - switch runtime.GOOS { - case "linux", "android": - fwmarkIoctl = 36 /* unix.SO_MARK */ - case "freebsd": - fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ - case "openbsd": - fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ - } -} - -func (bind *nativeBind) SetMark(mark uint32) error { - var operr error - if fwmarkIoctl == 0 { - return nil - } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - return nil -} diff --git a/device/peer.go b/device/peer.go index 19434cd..79d4981 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,6 +12,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.zx2c4.com/wireguard/conn" ) const ( @@ -24,7 +26,7 @@ type Peer struct { keypairs Keypairs handshake Handshake device *Device - endpoint Endpoint + endpoint conn.Endpoint persistentKeepaliveInterval uint16 // These fields are accessed with atomic operations, which must be @@ -290,7 +292,7 @@ func (peer *Peer) Stop() { var RoamingDisabled bool -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { if RoamingDisabled { return } diff --git a/device/receive.go b/device/receive.go index 7d0693e..4818d64 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,12 +17,13 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } @@ -33,7 +34,7 @@ type QueueInboundElement struct { packet []byte counter uint64 keypair *Keypair - endpoint Endpoint + endpoint conn.Endpoint } func (elem *QueueInboundElement) Drop() { @@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { +func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { logDebug := device.log.Debug defer func() { @@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { var ( err error size int - endpoint Endpoint + endpoint conn.Endpoint ) for { diff --git a/device/sticky_default.go b/device/sticky_default.go new file mode 100644 index 0000000..1cc52f6 --- /dev/null +++ b/device/sticky_default.go @@ -0,0 +1,12 @@ +// +build !linux + +package device + +import ( + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + return nil, nil +} diff --git a/device/sticky_linux.go b/device/sticky_linux.go new file mode 100644 index 0000000..f9522c2 --- /dev/null +++ b/device/sticky_linux.go @@ -0,0 +1,215 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "sync" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + netlinkSock, err := createNetlinkRouteSocket() + if err != nil { + return nil, err + } + netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) + if err != nil { + unix.Close(netlinkSock) + return nil, err + } + + go device.routineRouteListener(bind, netlinkSock, netlinkCancel) + + return netlinkCancel, nil +} + +func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { + type peerEndpointPtr struct { + peer *Peer + endpoint *conn.Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil { + peer.RUnlock() + continue + } + nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + if nativeEP == nil { + peer.RUnlock() + continue + } + if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + nativeEP.Dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + nativeEP.Src4().Src, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_MARK, + }, + uint32(bind.LastMark()), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil +} diff --git a/device/uapi.go b/device/uapi.go index 72611ab..6cdccd6 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { err := func() error { peer.Lock() defer peer.Unlock() - endpoint, err := CreateEndpoint(value) + endpoint, err := conn.CreateEndpoint(value) if err != nil { return err } -- cgit v1.2.3-59-g8ed1b