diff options
Diffstat (limited to 'device/sticky_linux.go')
-rw-r--r-- | device/sticky_linux.go | 59 |
1 files changed, 33 insertions, 26 deletions
diff --git a/device/sticky_linux.go b/device/sticky_linux.go index e3efc86..f23ff02 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,8 +1,6 @@ -// +build !android - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port @@ -11,7 +9,7 @@ * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. + * So this code remains platform dependent. */ package device @@ -21,11 +19,19 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + if !conn.StdNetSupportsStickySockets { + return nil, nil + } + if _, ok := bind.(*conn.StdNetBind); !ok { + return nil, nil + } + netlinkSock, err := createNetlinkRouteSocket() if err != nil { return nil, err @@ -41,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er return netlinkCancel, nil } -func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { +func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { type peerEndpointPtr struct { peer *Peer endpoint *conn.Endpoint @@ -49,6 +55,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl var reqPeer map[uint32]peerEndpointPtr var reqPeerLock sync.Mutex + defer netlinkCancel.Close() defer unix.Close(netlinkSock) for msg := make([]byte, 1<<16); ; { @@ -103,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -127,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } - if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { - peer.RUnlock() + if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { + peer.endpoint.Unlock() break } nlmsg := struct { @@ -165,26 +172,26 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl Len: 8, Type: unix.RTA_DST, }, - nativeEP.Dst4().Addr, + nativeEP.DstIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_SRC, }, - nativeEP.Src4().Src, + nativeEP.SrcIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_MARK, }, - uint32(bind.LastMark()), + device.net.fwmark, } nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { @@ -200,13 +207,13 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl } func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + Groups: unix.RTMGRP_IPV4_ROUTE, } err = unix.Bind(sock, saddr) if err != nil { |