diff options
Diffstat (limited to 'device/uapi.go')
-rw-r--r-- | device/uapi.go | 632 |
1 files changed, 324 insertions, 308 deletions
diff --git a/device/uapi.go b/device/uapi.go index 999eeb5..d81dae3 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,43 +1,77 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bufio" + "bytes" + "errors" "fmt" "io" "net" + "net/netip" "strconv" "strings" - "sync/atomic" + "sync" "time" "golang.zx2c4.com/wireguard/ipc" ) type IPCError struct { - int64 + code int64 // error code + err error // underlying/wrapped error } func (s IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) + return fmt.Sprintf("IPC error %d: %v", s.code, s.err) +} + +func (s IPCError) Unwrap() error { + return s.err } func (s IPCError) ErrorCode() int64 { - return s.int64 + return s.code +} + +func ipcErrorf(code int64, msg string, args ...any) *IPCError { + return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } -func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) +var byteBufferPool = &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcGetOperation(w io.Writer) error { + device.ipcMutex.RLock() + defer device.ipcMutex.RUnlock() + + buf := byteBufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer byteBufferPool.Put(buf) + sendf := func(format string, args ...any) { + fmt.Fprintf(buf, format, args...) + buf.WriteByte('\n') + } + keyf := func(prefix string, key *[32]byte) { + buf.Grow(len(key)*2 + 2 + len(prefix)) + buf.WriteString(prefix) + buf.WriteByte('=') + const hex = "0123456789abcdef" + for i := 0; i < len(key); i++ { + buf.WriteByte(hex[key[i]>>4]) + buf.WriteByte(hex[key[i]&0xf]) + } + buf.WriteByte('\n') } func() { - // lock required resources device.net.RLock() @@ -52,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.ToHex()) + keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) } if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) + sendf("listen_port=%d", device.net.port) } if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + sendf("fwmark=%d", device.net.fwmark) } - // serialize each peer state - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - send("protocol_version=1") - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) + // Serialize peer state. + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) } + peer.endpoint.Unlock() - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) - send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - - for _, ip := range device.allowedips.EntriesForPeer(peer) { - send("allowed_ip=" + ip.String()) - } + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ipc.IpcErrorIO} - } + if _, err := w.Write(buf.Bytes()); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } return nil } -func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcSetOperation(r io.Reader) (err error) { + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - var peer *Peer + defer func() { + if err != nil { + device.log.Errorf("%v", err) + } + }() - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true + scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. + peer.handlePostConfig() return nil } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{ipc.IpcErrorProtocol} + key, value, ok := strings.Cut(line, "=") + if !ok { + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipc.IpcErrorInvalid} } + peer.handlePostConfig() + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err + } + } + peer.handlePostConfig() - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - if peer == nil { - dummy = true - peer = &Peer{} - } else { - logDebug.Println(peer, "- UAPI: Created") - } - } + if err := scanner.Err(); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) + } + return nil +} - case "update_only": +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Verbosef("UAPI: Updating private key") + device.SetPrivateKey(sk) - // allow disabling of creation + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set update only, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } + // update port and rebind + device.log.Verbosef("UAPI: Updating listen port") - case "remove": + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() - // remove currently selected peer from device + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true + case "fwmark": + mark, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } - case "preshared_key": + device.log.Verbosef("UAPI: Updating fwmark") + if err := device.BindSetMark(uint32(mark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } - // update PSK + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Verbosef("UAPI: Removing all peers") + device.RemoveAllPeers() - logDebug.Println(peer, "- UAPI: Updating preshared key") + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() + return nil +} - if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer + pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on +} - case "endpoint": +func (peer *ipcSetPeer) handlePostConfig() { + if peer.Peer == nil || peer.dummy { + return + } + if peer.created { + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil + } + if peer.device.isUp() { + peer.Start() + if peer.pkaOn { + peer.SendKeepalive() + } + peer.SendStagedPackets() + } +} - // set endpoint destination +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } - logDebug.Println(peer, "- UAPI: Updating endpoint") + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } - if err != nil { - logError.Println("Failed to set endpoint:", err, ":", value) - return &IPCError{ipc.IpcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipc.IpcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Verbosef("%v - UAPI: Created", peer.Peer) + } + return nil +} - case "replace_allowed_ips": +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } - logDebug.Println(peer, "- UAPI: Removing all allowedips") + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Verbosef("%v - UAPI: Removing", peer.Peer) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true - if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "preshared_key": + device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) - if dummy { - continue - } + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() - device.allowedips.RemoveByPeer(peer) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } - case "allowed_ip": + case "endpoint": + device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) + endpoint, err := device.net.bind.ParseEndpoint(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint - logDebug.Println(peer, "- UAPI: Adding allowedip") + case "persistent_keepalive_interval": + device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipc.IpcErrorInvalid} - } + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } - if dummy { - continue - } + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) + // Send immediate keepalive if we're turning it on and before it wasn't on. + peer.pkaOn = old == 0 && secs != 0 - case "protocol_version": + case "replace_allowed_ips": + device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) - if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "allowed_ip": + device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + prefix, err := netip.ParsePrefix(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + device.allowedips.Insert(prefix, peer.Peer) - default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipc.IpcErrorInvalid} - } + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } return nil } -func (device *Device) IpcHandle(socket net.Conn) { +func (device *Device) IpcGet() (string, error) { + buf := new(strings.Builder) + if err := device.IpcGetOperation(buf); err != nil { + return "", err + } + return buf.String(), nil +} - // create buffered read/writer +func (device *Device) IpcSet(uapiConf string) error { + return device.IpcSetOperation(strings.NewReader(uapiConf)) +} +func (device *Device) IpcHandle(socket net.Conn) { defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -407,35 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) { return bufio.NewReadWriter(reader, writer) }(socket) - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - status = device.IpcSetOperation(buffered.Reader) - - case "get=1\n": - status = device.IpcGetOperation(buffered.Writer) - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } + for { + op, err := buffered.ReadString('\n') + if err != nil { + return + } - // write status + // handle operation + switch op { + case "set=1\n": + err = device.IpcSetOperation(buffered.Reader) + case "get=1\n": + var nextByte byte + nextByte, err = buffered.ReadByte() + if err != nil { + return + } + if nextByte != '\n' { + err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + break + } + err = device.IpcGetOperation(buffered.Writer) + default: + device.log.Errorf("invalid UAPI operation: %v", op) + return + } - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") + // write status + var status *IPCError + if err != nil && !errors.As(err, &status) { + // shouldn't happen + status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) + } + if status != nil { + device.log.Errorf("%v", status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() } } |