aboutsummaryrefslogtreecommitdiffstats
path: root/device/uapi.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/uapi.go')
-rw-r--r--device/uapi.go95
1 files changed, 53 insertions, 42 deletions
diff --git a/device/uapi.go b/device/uapi.go
index 66ecd48..cc69488 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,10 +12,10 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
"sync"
- "sync/atomic"
"time"
"golang.zx2c4.com/wireguard/ipc"
@@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
return s.code
}
-func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
var byteBufferPool = &sync.Pool{
- New: func() interface{} { return new(bytes.Buffer) },
+ New: func() any { return new(bytes.Buffer) },
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
@@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer byteBufferPool.Put(buf)
- sendf := func(format string, args ...interface{}) {
+ sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n')
}
@@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
func() {
-
// lock required resources
device.net.RLock()
@@ -98,31 +97,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
- if peer.endpoint != nil {
- sendf("endpoint=%s", peer.endpoint.DstToString())
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
- sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
- sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -156,14 +155,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
if key == "public_key" {
if deviceConfig {
@@ -254,10 +252,21 @@ type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
}
func (peer *ipcSetPeer) handlePostConfig() {
- if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
peer.SendStagedPackets()
}
}
@@ -334,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
- peer.Lock()
- defer peer.Unlock()
- peer.endpoint = endpoint
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
@@ -346,17 +355,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
- old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on.
- if old == 0 && secs != 0 {
- if err != nil {
- return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
- }
- if device.isUp() && !peer.dummy {
- peer.SendKeepalive()
- }
- }
+ peer.pkaOn = old == 0 && secs != 0
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
@@ -369,17 +371,26 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
- device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ add := true
+ verb := "Adding"
+ if len(value) > 0 && value[0] == '-' {
+ add = false
+ verb = "Removing"
+ value = value[1:]
+ }
+ device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ if add {
+ device.allowedips.Insert(prefix, peer.Peer)
+ } else {
+ device.allowedips.Remove(prefix, peer.Peer)
+ }
case "protocol_version":
if value != "1" {