From 8c34c4cbb3780c433148966a004f5a51aace0f64 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 4 Aug 2017 16:15:53 +0200 Subject: First set of code review patches --- src/config.go | 225 +++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 136 insertions(+), 89 deletions(-) (limited to 'src/config.go') diff --git a/src/config.go b/src/config.go index 72a604f..e2d7f20 100644 --- a/src/config.go +++ b/src/config.go @@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint64(&peer.persistentKeepaliveInterval), )) + for _, ip := range device.routingTable.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } @@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logDebug := device.log.Debug var peer *Peer + + deviceConfig := true + for scanner.Scan() { // parse line @@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } parts := strings.Split(line, "=") if len(parts) != 2 { - return &IPCError{Code: ipcErrorNoKeyValue} + return &IPCError{Code: ipcErrorProtocol} } key := parts[0] value := parts[1] - switch key { + /* device configuration */ - /* interface configuration */ + if deviceConfig { - case "private_key": - var sk NoisePrivateKey - if value == "" { - device.SetPrivateKey(sk) - } else { - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} + switch key { + case "private_key": + var sk NoisePrivateKey + if value == "" { + device.SetPrivateKey(sk) + } else { + err := sk.FromHex(value) + if err != nil { + logError.Println("Failed to set private_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + device.SetPrivateKey(sk) } - device.SetPrivateKey(sk) - } - case "listen_port": - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } - netc := &device.net - netc.mutex.Lock() - if netc.addr.Port != int(port) { - if netc.conn != nil { - netc.conn.Close() + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to set listen_port:", err) + return &IPCError{Code: ipcErrorInvalid} } - netc.addr.Port = int(port) - netc.conn, err = net.ListenUDP("udp", netc.addr) - } - netc.mutex.Unlock() - if err != nil { - logError.Println("Failed to create UDP listener:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } + netc := &device.net + netc.mutex.Lock() + if netc.addr.Port != int(port) { + if netc.conn != nil { + netc.conn.Close() + } + netc.addr.Port = int(port) + netc.conn, err = net.ListenUDP("udp", netc.addr) + } + netc.mutex.Unlock() + if err != nil { + logError.Println("Failed to create UDP listener:", err) + return &IPCError{Code: ipcErrorIO} + } + // TODO: Clear source address of all peers - case "fwmark": - logError.Println("FWMark not handled yet") + case "fwmark": + logError.Println("FWMark not handled yet") + // TODO: Clear source address of all peers - case "public_key": - var pubKey NoisePublicKey - err := pubKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } - device.mutex.RLock() - peer, _ = device.peers[pubKey] - device.mutex.RUnlock() - if peer == nil { - peer = device.NewPeer(pubKey) - } + case "public_key": - case "replace_peers": - if value == "true" { - device.RemoveAllPeers() - } else { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{Code: ipcErrorInvalidValue} - } + // switch to peer configuration - default: + deviceConfig = false - /* peer configuration */ + case "replace_peers": + if value != "true" { + logError.Println("Failed to set replace_peers, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } + device.RemoveAllPeers() - if peer == nil { - logError.Println("No peer referenced, before peer operation") - return &IPCError{Code: ipcErrorNoPeer} + default: + logError.Println("Invalid UAPI key (device configuration):", key) + return &IPCError{Code: ipcErrorInvalid} } + } + + /* peer configuration */ + + if !deviceConfig { switch key { + case "public_key": + var pubKey NoisePublicKey + err := pubKey.FromHex(value) + if err != nil { + logError.Println("Failed to get peer by public_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + // check if public key of peer equal to device + + device.mutex.RLock() + if device.publicKey.Equals(pubKey) { + device.mutex.RUnlock() + logError.Println("Public key of peer matches private key of device") + return &IPCError{Code: ipcErrorInvalid} + } + + // find peer referenced + + peer, _ = device.peers[pubKey] + device.mutex.RUnlock() + if peer == nil { + peer = device.NewPeer(pubKey) + } + case "remove": + if value != "true" { + logError.Println("Failed to set remove, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } device.RemovePeer(peer.handshake.remoteStatic) logDebug.Println("Removing", peer.String()) peer = nil @@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { }() if err != nil { logError.Println("Failed to set preshared_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } case "endpoint": + // TODO: Only IP and port addr, err := net.ResolveUDPAddr("udp", value) if err != nil { logError.Println("Failed to set endpoint:", value) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } peer.mutex.Lock() peer.endpoint = addr peer.mutex.Unlock() case "persistent_keepalive_interval": - secs, err := strconv.ParseInt(value, 10, 64) - if secs < 0 || err != nil { + + // update keep-alive interval + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { logError.Println("Failed to set persistent_keepalive_interval:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } - atomic.StoreUint64( + + old := atomic.SwapUint64( &peer.persistentKeepaliveInterval, - uint64(secs), + secs, ) + // send immediate keep-alive + + if old == 0 && secs != 0 { + up, err := device.tun.IsUp() + if err != nil { + logError.Println("Failed to get tun device status:", err) + return &IPCError{Code: ipcErrorIO} + } + if up { + peer.SendKeepAlive() + } + } + case "replace_allowed_ips": - if value == "true" { - device.routingTable.RemovePeer(peer) - } else { + if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } + device.routingTable.RemovePeer(peer) case "allowed_ip": _, network, err := net.ParseCIDR(value) if err != nil { logError.Println("Failed to set allowed_ip:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } ones, _ := network.Mask.Size() device.routingTable.Insert(network.IP, uint(ones), peer) default: - logError.Println("Invalid UAPI key:", key) - return &IPCError{Code: ipcErrorInvalidKey} + logError.Println("Invalid UAPI key (peer configuration):", key) + return &IPCError{Code: ipcErrorInvalid} } } } @@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcHandle(device *Device, socket net.Conn) { + // create buffered read/writer + defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) { return } - switch op { + // handle operation + var status *IPCError + + switch op { case "set=1\n": device.log.Debug.Println("Config, set operation") - err := ipcSetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - return + status = ipcSetOperation(device, buffered) case "get=1\n": device.log.Debug.Println("Config, get operation") - err := ipcGetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - return + status = ipcGetOperation(device, buffered) default: device.log.Error.Println("Invalid UAPI operation:", op) + return + } + + // write status + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") } } -- cgit v1.2.3-59-g8ed1b