aboutsummaryrefslogtreecommitdiffstats
path: root/device/uapi.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/uapi.go')
-rw-r--r--device/uapi.go632
1 files changed, 324 insertions, 308 deletions
diff --git a/device/uapi.go b/device/uapi.go
index 999eeb5..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,43 +1,77 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
+ "bytes"
+ "errors"
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
- "sync/atomic"
+ "sync"
"time"
"golang.zx2c4.com/wireguard/ipc"
)
type IPCError struct {
- int64
+ code int64 // error code
+ err error // underlying/wrapped error
}
func (s IPCError) Error() string {
- return fmt.Sprintf("IPC error: %d", s.int64)
+ return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
+}
+
+func (s IPCError) Unwrap() error {
+ return s.err
}
func (s IPCError) ErrorCode() int64 {
- return s.int64
+ return s.code
+}
+
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
+ return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
-func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
- lines := make([]string, 0, 100)
- send := func(line string) {
- lines = append(lines, line)
+var byteBufferPool = &sync.Pool{
+ New: func() any { return new(bytes.Buffer) },
+}
+
+// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcGetOperation(w io.Writer) error {
+ device.ipcMutex.RLock()
+ defer device.ipcMutex.RUnlock()
+
+ buf := byteBufferPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer byteBufferPool.Put(buf)
+ sendf := func(format string, args ...any) {
+ fmt.Fprintf(buf, format, args...)
+ buf.WriteByte('\n')
+ }
+ keyf := func(prefix string, key *[32]byte) {
+ buf.Grow(len(key)*2 + 2 + len(prefix))
+ buf.WriteString(prefix)
+ buf.WriteByte('=')
+ const hex = "0123456789abcdef"
+ for i := 0; i < len(key); i++ {
+ buf.WriteByte(hex[key[i]>>4])
+ buf.WriteByte(hex[key[i]&0xf])
+ }
+ buf.WriteByte('\n')
}
func() {
-
// lock required resources
device.net.RLock()
@@ -52,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
- send("private_key=" + device.staticIdentity.privateKey.ToHex())
+ keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
}
if device.net.port != 0 {
- send(fmt.Sprintf("listen_port=%d", device.net.port))
+ sendf("listen_port=%d", device.net.port)
}
if device.net.fwmark != 0 {
- send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
- send("public_key=" + peer.handshake.remoteStatic.ToHex())
- send("preshared_key=" + peer.handshake.presharedKey.ToHex())
- send("protocol_version=1")
- if peer.endpoint != nil {
- send("endpoint=" + peer.endpoint.DstToString())
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
+ keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+ keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
+ sendf("protocol_version=1")
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
- send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
- send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
- send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
- send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
- send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
-
- for _, ip := range device.allowedips.EntriesForPeer(peer) {
- send("allowed_ip=" + ip.String())
- }
+ sendf("last_handshake_time_sec=%d", secs)
+ sendf("last_handshake_time_nsec=%d", nano)
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ })
}
}()
// send lines (does not require resource locks)
-
- for _, line := range lines {
- _, err := socket.WriteString(line + "\n")
- if err != nil {
- return &IPCError{ipc.IpcErrorIO}
- }
+ if _, err := w.Write(buf.Bytes()); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
}
return nil
}
-func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
- scanner := bufio.NewScanner(socket)
- logError := device.log.Error
- logDebug := device.log.Debug
+// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcSetOperation(r io.Reader) (err error) {
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
- var peer *Peer
+ defer func() {
+ if err != nil {
+ device.log.Errorf("%v", err)
+ }
+ }()
- dummy := false
- createdNewPeer := false
+ peer := new(ipcSetPeer)
deviceConfig := true
+ scanner := bufio.NewScanner(r)
for scanner.Scan() {
-
- // parse line
-
line := scanner.Text()
if line == "" {
+ // Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return &IPCError{ipc.IpcErrorProtocol}
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
-
- /* device configuration */
-
- if deviceConfig {
-
- switch key {
- case "private_key":
- var sk NoisePrivateKey
- err := sk.FromHex(value)
- if err != nil {
- logError.Println("Failed to set private_key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Updating private key")
- device.SetPrivateKey(sk)
-
- case "listen_port":
-
- // parse port number
-
- port, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to parse listen_port:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // update port and rebind
-
- logDebug.Println("UAPI: Updating listen port")
-
- device.net.Lock()
- device.net.port = uint16(port)
- device.net.Unlock()
-
- if err := device.BindUpdate(); err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
- case "fwmark":
-
- // parse fwmark field
-
- fwmark, err := func() (uint32, error) {
- if value == "" {
- return 0, nil
- }
- mark, err := strconv.ParseUint(value, 10, 32)
- return uint32(mark), err
- }()
-
- if err != nil {
- logError.Println("Invalid fwmark", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- logDebug.Println("UAPI: Updating fwmark")
-
- if err := device.BindSetMark(uint32(fwmark)); err != nil {
- logError.Println("Failed to update fwmark:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
-
- case "public_key":
- // switch to peer configuration
- logDebug.Println("UAPI: Transition to peer configuration")
+ if key == "public_key" {
+ if deviceConfig {
deviceConfig = false
-
- case "replace_peers":
- if value != "true" {
- logError.Println("Failed to set replace_peers, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Removing all peers")
- device.RemoveAllPeers()
-
- default:
- logError.Println("Invalid UAPI device key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
}
+ peer.handlePostConfig()
+ // Load/create the peer we are now configuring.
+ err := device.handlePublicKeyLine(peer, value)
+ if err != nil {
+ return err
+ }
+ continue
}
- /* peer configuration */
-
- if !deviceConfig {
-
- switch key {
-
- case "public_key":
- var publicKey NoisePublicKey
- err := publicKey.FromHex(value)
- if err != nil {
- logError.Println("Failed to get peer by public key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // ignore peer with public key of device
-
- device.staticIdentity.RLock()
- dummy = device.staticIdentity.publicKey.Equals(publicKey)
- device.staticIdentity.RUnlock()
-
- if dummy {
- peer = &Peer{}
- } else {
- peer = device.LookupPeer(publicKey)
- }
+ var err error
+ if deviceConfig {
+ err = device.handleDeviceLine(key, value)
+ } else {
+ err = device.handlePeerLine(peer, key, value)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ peer.handlePostConfig()
- createdNewPeer = peer == nil
- if createdNewPeer {
- peer, err = device.NewPeer(publicKey)
- if err != nil {
- logError.Println("Failed to create new peer:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if peer == nil {
- dummy = true
- peer = &Peer{}
- } else {
- logDebug.Println(peer, "- UAPI: Created")
- }
- }
+ if err := scanner.Err(); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
+ }
+ return nil
+}
- case "update_only":
+func (device *Device) handleDeviceLine(key, value string) error {
+ switch key {
+ case "private_key":
+ var sk NoisePrivateKey
+ err := sk.FromMaybeZeroHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating private key")
+ device.SetPrivateKey(sk)
- // allow disabling of creation
+ case "listen_port":
+ port, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set update only, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if createdNewPeer && !dummy {
- device.RemovePeer(peer.handshake.remoteStatic)
- peer = &Peer{}
- dummy = true
- }
+ // update port and rebind
+ device.log.Verbosef("UAPI: Updating listen port")
- case "remove":
+ device.net.Lock()
+ device.net.port = uint16(port)
+ device.net.Unlock()
- // remove currently selected peer from device
+ if err := device.BindUpdate(); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set remove, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if !dummy {
- logDebug.Println(peer, "- UAPI: Removing")
- device.RemovePeer(peer.handshake.remoteStatic)
- }
- peer = &Peer{}
- dummy = true
+ case "fwmark":
+ mark, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
+ }
- case "preshared_key":
+ device.log.Verbosef("UAPI: Updating fwmark")
+ if err := device.BindSetMark(uint32(mark)); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
+ }
- // update PSK
+ case "replace_peers":
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
+ }
+ device.log.Verbosef("UAPI: Removing all peers")
+ device.RemoveAllPeers()
- logDebug.Println(peer, "- UAPI: Updating preshared key")
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
+ }
- peer.handshake.mutex.Lock()
- err := peer.handshake.presharedKey.FromHex(value)
- peer.handshake.mutex.Unlock()
+ return nil
+}
- if err != nil {
- logError.Println("Failed to set preshared key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+// An ipcSetPeer is the current state of an IPC set operation on a peer.
+type ipcSetPeer struct {
+ *Peer // Peer is the current peer being operated on
+ dummy bool // dummy reports whether this peer is a temporary, placeholder peer
+ created bool // new reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
+}
- case "endpoint":
+func (peer *ipcSetPeer) handlePostConfig() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
+ peer.SendStagedPackets()
+ }
+}
- // set endpoint destination
+func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
+ // Load/create the peer we are configuring.
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
+ }
- logDebug.Println(peer, "- UAPI: Updating endpoint")
+ // Ignore peer with the same public key as this device.
+ device.staticIdentity.RLock()
+ peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
+ device.staticIdentity.RUnlock()
- err := func() error {
- peer.Lock()
- defer peer.Unlock()
- endpoint, err := CreateEndpoint(value)
- if err != nil {
- return err
- }
- peer.endpoint = endpoint
- return nil
- }()
+ if peer.dummy {
+ peer.Peer = &Peer{}
+ } else {
+ peer.Peer = device.LookupPeer(publicKey)
+ }
- if err != nil {
- logError.Println("Failed to set endpoint:", err, ":", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- case "persistent_keepalive_interval":
-
- // update persistent keepalive interval
-
- logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
-
- secs, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to set persistent keepalive interval:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- old := peer.persistentKeepaliveInterval
- peer.persistentKeepaliveInterval = uint16(secs)
-
- // send immediate keepalive if we're turning it on and before it wasn't on
-
- if old == 0 && secs != 0 {
- if err != nil {
- logError.Println("Failed to get tun device status:", err)
- return &IPCError{ipc.IpcErrorIO}
- }
- if device.isUp.Get() && !dummy {
- peer.SendKeepalive()
- }
- }
+ peer.created = peer.Peer == nil
+ if peer.created {
+ peer.Peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
+ }
+ device.log.Verbosef("%v - UAPI: Created", peer.Peer)
+ }
+ return nil
+}
- case "replace_allowed_ips":
+func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
+ switch key {
+ case "update_only":
+ // allow disabling of creation
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
+ }
+ if peer.created && !peer.dummy {
+ device.RemovePeer(peer.handshake.remoteStatic)
+ peer.Peer = &Peer{}
+ peer.dummy = true
+ }
- logDebug.Println(peer, "- UAPI: Removing all allowedips")
+ case "remove":
+ // remove currently selected peer from device
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
+ }
+ if !peer.dummy {
+ device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
+ device.RemovePeer(peer.handshake.remoteStatic)
+ }
+ peer.Peer = &Peer{}
+ peer.dummy = true
- if value != "true" {
- logError.Println("Failed to replace allowedips, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "preshared_key":
+ device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
- if dummy {
- continue
- }
+ peer.handshake.mutex.Lock()
+ err := peer.handshake.presharedKey.FromHex(value)
+ peer.handshake.mutex.Unlock()
- device.allowedips.RemoveByPeer(peer)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
+ }
- case "allowed_ip":
+ case "endpoint":
+ device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
+ endpoint, err := device.net.bind.ParseEndpoint(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
+ }
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
- logDebug.Println(peer, "- UAPI: Adding allowedip")
+ case "persistent_keepalive_interval":
+ device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
- _, network, err := net.ParseCIDR(value)
- if err != nil {
- logError.Println("Failed to set allowed ip:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ secs, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
+ }
- if dummy {
- continue
- }
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint(ones), peer)
+ // Send immediate keepalive if we're turning it on and before it wasn't on.
+ peer.pkaOn = old == 0 && secs != 0
- case "protocol_version":
+ case "replace_allowed_ips":
+ device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.RemoveByPeer(peer.Peer)
- if value != "1" {
- logError.Println("Invalid protocol version:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "allowed_ip":
+ device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
+ prefix, err := netip.ParsePrefix(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.Insert(prefix, peer.Peer)
- default:
- logError.Println("Invalid UAPI peer key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "protocol_version":
+ if value != "1" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
+
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
return nil
}
-func (device *Device) IpcHandle(socket net.Conn) {
+func (device *Device) IpcGet() (string, error) {
+ buf := new(strings.Builder)
+ if err := device.IpcGetOperation(buf); err != nil {
+ return "", err
+ }
+ return buf.String(), nil
+}
- // create buffered read/writer
+func (device *Device) IpcSet(uapiConf string) error {
+ return device.IpcSetOperation(strings.NewReader(uapiConf))
+}
+func (device *Device) IpcHandle(socket net.Conn) {
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -407,35 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer)
}(socket)
- defer buffered.Flush()
-
- op, err := buffered.ReadString('\n')
- if err != nil {
- return
- }
-
- // handle operation
-
- var status *IPCError
-
- switch op {
- case "set=1\n":
- status = device.IpcSetOperation(buffered.Reader)
-
- case "get=1\n":
- status = device.IpcGetOperation(buffered.Writer)
-
- default:
- device.log.Error.Println("Invalid UAPI operation:", op)
- return
- }
+ for {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
- // write status
+ // handle operation
+ switch op {
+ case "set=1\n":
+ err = device.IpcSetOperation(buffered.Reader)
+ case "get=1\n":
+ var nextByte byte
+ nextByte, err = buffered.ReadByte()
+ if err != nil {
+ return
+ }
+ if nextByte != '\n' {
+ err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
+ break
+ }
+ err = device.IpcGetOperation(buffered.Writer)
+ default:
+ device.log.Errorf("invalid UAPI operation: %v", op)
+ return
+ }
- if status != nil {
- device.log.Error.Println(status)
- fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
+ // write status
+ var status *IPCError
+ if err != nil && !errors.As(err, &status) {
+ // shouldn't happen
+ status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
+ }
+ if status != nil {
+ device.log.Errorf("%v", status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ buffered.Flush()
}
}