aboutsummaryrefslogtreecommitdiffstats
path: root/device/sticky_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/sticky_linux.go')
-rw-r--r--device/sticky_linux.go59
1 files changed, 33 insertions, 26 deletions
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index e3efc86..f23ff02 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,8 +1,6 @@
-// +build !android
-
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -11,7 +9,7 @@
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
+ * So this code remains platform dependent.
*/
package device
@@ -21,11 +19,19 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
+ return nil, nil
+ }
+
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
@@ -41,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
return netlinkCancel, nil
}
-func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
+func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
@@ -49,6 +55,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
+ defer netlinkCancel.Close()
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
@@ -103,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -127,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
- peer.RUnlock()
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -165,26 +172,26 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
- uint32(bind.LastMark()),
+ device.net.fwmark,
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
@@ -200,13 +207,13 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
}
func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+ Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {