diff options
Diffstat (limited to 'device/uapi.go')
-rw-r--r-- | device/uapi.go | 95 |
1 files changed, 53 insertions, 42 deletions
diff --git a/device/uapi.go b/device/uapi.go index 66ecd48..cc69488 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -12,10 +12,10 @@ import ( "fmt" "io" "net" + "net/netip" "strconv" "strings" "sync" - "sync/atomic" "time" "golang.zx2c4.com/wireguard/ipc" @@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 { return s.code } -func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { +func ipcErrorf(code int64, msg string, args ...any) *IPCError { return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } var byteBufferPool = &sync.Pool{ - New: func() interface{} { return new(bytes.Buffer) }, + New: func() any { return new(bytes.Buffer) }, } // IpcGetOperation implements the WireGuard configuration protocol "get" operation. @@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error { buf := byteBufferPool.Get().(*bytes.Buffer) buf.Reset() defer byteBufferPool.Put(buf) - sendf := func(format string, args ...interface{}) { + sendf := func(format string, args ...any) { fmt.Fprintf(buf, format, args...) buf.WriteByte('\n') } @@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } func() { - // lock required resources device.net.RLock() @@ -98,31 +97,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - // serialize each peer state - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - + // Serialize peer state. + peer.handshake.mutex.RLock() keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() sendf("protocol_version=1") - if peer.endpoint != nil { - sendf("endpoint=%s", peer.endpoint.DstToString()) + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) } + peer.endpoint.Unlock() - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() sendf("last_handshake_time_sec=%d", secs) sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)) - sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) - sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { - sendf("allowed_ip=%s/%d", ip.String(), cidr) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) return true }) } @@ -156,14 +155,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { line := scanner.Text() if line == "" { // Blank line means terminate operation. + peer.handlePostConfig() return nil } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts)) + key, value, ok := strings.Cut(line, "=") + if !ok { + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } - key := parts[0] - value := parts[1] if key == "public_key" { if deviceConfig { @@ -254,10 +252,21 @@ type ipcSetPeer struct { *Peer // Peer is the current peer being operated on dummy bool // dummy reports whether this peer is a temporary, placeholder peer created bool // new reports whether this is a newly created peer + pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on } func (peer *ipcSetPeer) handlePostConfig() { - if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() { + if peer.Peer == nil || peer.dummy { + return + } + if peer.created { + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil + } + if peer.device.isUp() { + peer.Start() + if peer.pkaOn { + peer.SendKeepalive() + } peer.SendStagedPackets() } } @@ -334,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } - peer.Lock() - defer peer.Unlock() - peer.endpoint = endpoint + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) @@ -346,17 +355,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) } - old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) // Send immediate keepalive if we're turning it on and before it wasn't on. - if old == 0 && secs != 0 { - if err != nil { - return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) - } - if device.isUp() && !peer.dummy { - peer.SendKeepalive() - } - } + peer.pkaOn = old == 0 && secs != 0 case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) @@ -369,17 +371,26 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error device.allowedips.RemoveByPeer(peer.Peer) case "allowed_ip": - device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) - - _, network, err := net.ParseCIDR(value) + add := true + verb := "Adding" + if len(value) > 0 && value[0] == '-' { + add = false + verb = "Removing" + value = value[1:] + } + device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb) + prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if peer.dummy { return nil } - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) + if add { + device.allowedips.Insert(prefix, peer.Peer) + } else { + device.allowedips.Remove(prefix, peer.Peer) + } case "protocol_version": if value != "1" { |