diff options
Diffstat (limited to 'conf/parser.go')
-rw-r--r-- | conf/parser.go | 273 |
1 files changed, 93 insertions, 180 deletions
diff --git a/conf/parser.go b/conf/parser.go index da21e796..b1da9816 100644 --- a/conf/parser.go +++ b/conf/parser.go @@ -1,20 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( "encoding/base64" - "encoding/hex" - "net" + "net/netip" "strconv" "strings" - "time" + "golang.org/x/sys/windows" "golang.org/x/text/encoding/unicode" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/l18n" ) @@ -27,43 +27,16 @@ func (e *ParseError) Error() string { return l18n.Sprintf("%s: %q", e.why, e.offender) } -func parseIPCidr(s string) (ipcidr *IPCidr, err error) { - var addrStr, cidrStr string - var cidr int - - i := strings.IndexByte(s, '/') - if i < 0 { - addrStr = s - } else { - addrStr, cidrStr = s[:i], s[i+1:] - } - - err = &ParseError{l18n.Sprintf("Invalid IP address"), s} - addr := net.ParseIP(addrStr) - if addr == nil { - return +func parseIPCidr(s string) (netip.Prefix, error) { + ipcidr, err := netip.ParsePrefix(s) + if err == nil { + return ipcidr, nil } - maybeV4 := addr.To4() - if maybeV4 != nil { - addr = maybeV4 - } - if len(cidrStr) > 0 { - err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s} - cidr, err = strconv.Atoi(cidrStr) - if err != nil || cidr < 0 || cidr > 128 { - return - } - if cidr > 32 && maybeV4 != nil { - return - } - } else { - if maybeV4 != nil { - cidr = 32 - } else { - cidr = 128 - } + addr, err := netip.ParseAddr(s) + if err != nil { + return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s} } - return &IPCidr{addr, uint8(cidr)}, nil + return netip.PrefixFrom(addr, addr.BitLen()), nil } func parseEndpoint(s string) (*Endpoint, error) { @@ -87,8 +60,8 @@ func parseEndpoint(s string) (*Endpoint, error) { if i := strings.LastIndexByte(host, '%'); i > 1 { end = i } - maybeV6 := net.ParseIP(host[1:end]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { + maybeV6, err2 := netip.ParseAddr(host[1:end]) + if err2 != nil || !maybeV6.Is6() { return nil, err } } else { @@ -96,7 +69,7 @@ func parseEndpoint(s string) (*Endpoint, error) { } host = host[1 : len(host)-1] } - return &Endpoint{host, uint16(port)}, nil + return &Endpoint{host, port}, nil } func parseMTU(s string) (uint16, error) { @@ -135,21 +108,18 @@ func parsePersistentKeepalive(s string) (uint16, error) { return uint16(m), nil } -func parseKeyBase64(s string) (*Key, error) { - k, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} - } - if len(k) != KeyLength { - return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s} +func parseTableOff(s string) (bool, error) { + if s == "off" { + return true, nil + } else if s == "auto" || s == "main" { + return false, nil } - var key Key - copy(key[:], k) - return &key, nil + _, err := strconv.ParseUint(s, 10, 32) + return false, err } -func parseKeyHex(s string) (*Key, error) { - k, err := hex.DecodeString(s) +func parseKeyBase64(s string) (*Key, error) { + k, err := base64.StdEncoding.DecodeString(s) if err != nil { return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} } @@ -161,14 +131,6 @@ func parseKeyHex(s string) (*Key, error) { return &key, nil } -func parseBytesOrStamp(s string) (uint64, error) { - b, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s} - } - return b, nil -} - func splitList(s string) ([]string, error) { var out []string for _, split := range strings.Split(s, ",") { @@ -195,7 +157,7 @@ func (c *Config) maybeAddPeer(p *Peer) { } } -func FromWgQuick(s string, name string) (*Config, error) { +func FromWgQuick(s, name string) (*Config, error) { if !TunnelNameIsValid(name) { return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name} } @@ -205,10 +167,7 @@ func FromWgQuick(s string, name string) (*Config, error) { sawPrivateKey := false var peer *Peer for _, line := range lines { - pound := strings.IndexByte(line, '#') - if pound >= 0 { - line = line[:pound] - } + line, _, _ = strings.Cut(line, "#") line = strings.TrimSpace(line) lineLower := strings.ToLower(line) if len(line) == 0 { @@ -267,7 +226,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - conf.Interface.Addresses = append(conf.Interface.Addresses, *a) + conf.Interface.Addresses = append(conf.Interface.Addresses, a) } case "dns": addresses, err := splitList(val) @@ -275,13 +234,27 @@ func FromWgQuick(s string, name string) (*Config, error) { return nil, err } for _, address := range addresses { - a := net.ParseIP(address) - if a == nil { + a, err := netip.ParseAddr(address) + if err != nil { conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address) } else { conf.Interface.DNS = append(conf.Interface.DNS, a) } } + case "preup": + conf.Interface.PreUp = val + case "postup": + conf.Interface.PostUp = val + case "predown": + conf.Interface.PreDown = val + case "postdown": + conf.Interface.PostDown = val + case "table": + tableOff, err := parseTableOff(val) + if err != nil { + return nil, err + } + conf.Interface.TableOff = tableOff default: return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key} } @@ -309,7 +282,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) } case "persistentkeepalive": p, err := parsePersistentKeepalive(val) @@ -342,7 +315,7 @@ func FromWgQuick(s string, name string) (*Config, error) { return &conf, nil } -func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) { +func FromWgQuickWithUnknownEncoding(s, name string) (*Config, error) { c, firstErr := FromWgQuick(s, name) if firstErr == nil { return c, nil @@ -359,9 +332,7 @@ func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) { return nil, firstErr } -func FromUAPI(s string, existingConfig *Config) (*Config, error) { - lines := strings.Split(s, "\n") - parserState := inInterfaceSection +func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config) *Config { conf := Config{ Name: existingConfig.Name, Interface: Interface{ @@ -369,119 +340,61 @@ func FromUAPI(s string, existingConfig *Config) (*Config, error) { DNS: existingConfig.Interface.DNS, DNSSearch: existingConfig.Interface.DNSSearch, MTU: existingConfig.Interface.MTU, + PreUp: existingConfig.Interface.PreUp, + PostUp: existingConfig.Interface.PostUp, + PreDown: existingConfig.Interface.PreDown, + PostDown: existingConfig.Interface.PostDown, + TableOff: existingConfig.Interface.TableOff, }, } - var peer *Peer - for _, line := range lines { - if len(line) == 0 { - continue + if interfaze.Flags&driver.InterfaceHasPrivateKey != 0 { + conf.Interface.PrivateKey = interfaze.PrivateKey + } + if interfaze.Flags&driver.InterfaceHasListenPort != 0 { + conf.Interface.ListenPort = interfaze.ListenPort + } + var p *driver.Peer + for i := uint32(0); i < interfaze.PeerCount; i++ { + if p == nil { + p = interfaze.FirstPeer() + } else { + p = p.NextPeer() } - equals := strings.IndexByte(line, '=') - if equals < 0 { - return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line} + peer := Peer{} + if p.Flags&driver.PeerHasPublicKey != 0 { + peer.PublicKey = p.PublicKey } - key, val := line[:equals], line[equals+1:] - if len(val) == 0 { - return nil, &ParseError{l18n.Sprintf("Key must have a value"), line} + if p.Flags&driver.PeerHasPresharedKey != 0 { + peer.PresharedKey = p.PresharedKey } - switch key { - case "public_key": - conf.maybeAddPeer(peer) - peer = &Peer{} - parserState = inPeerSection - case "errno": - if val == "0" { - continue - } else { - return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val} - } + if p.Flags&driver.PeerHasEndpoint != 0 { + peer.Endpoint.Port = p.Endpoint.Port() + peer.Endpoint.Host = p.Endpoint.Addr().String() } - if parserState == inInterfaceSection { - switch key { - case "private_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - conf.Interface.PrivateKey = *k - case "listen_port": - p, err := parsePort(val) - if err != nil { - return nil, err - } - conf.Interface.ListenPort = p - case "fwmark": - // Ignored for now. - - default: - return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key} + if p.Flags&driver.PeerHasPersistentKeepalive != 0 { + peer.PersistentKeepalive = p.PersistentKeepalive + } + peer.TxBytes = Bytes(p.TxBytes) + peer.RxBytes = Bytes(p.RxBytes) + if p.LastHandshake != 0 { + peer.LastHandshakeTime = HandshakeTime((p.LastHandshake - 116444736000000000) * 100) + } + var a *driver.AllowedIP + for j := uint32(0); j < p.AllowedIPsCount; j++ { + if a == nil { + a = p.FirstAllowedIP() + } else { + a = a.NextAllowedIP() } - } else if parserState == inPeerSection { - switch key { - case "public_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PublicKey = *k - case "preshared_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PresharedKey = *k - case "protocol_version": - if val != "1" { - return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val} - } - case "allowed_ip": - a, err := parseIPCidr(val) - if err != nil { - return nil, err - } - peer.AllowedIPs = append(peer.AllowedIPs, *a) - case "persistent_keepalive_interval": - p, err := parsePersistentKeepalive(val) - if err != nil { - return nil, err - } - peer.PersistentKeepalive = p - case "endpoint": - e, err := parseEndpoint(val) - if err != nil { - return nil, err - } - peer.Endpoint = *e - case "tx_bytes": - b, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.TxBytes = Bytes(b) - case "rx_bytes": - b, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.RxBytes = Bytes(b) - case "last_handshake_time_sec": - t, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second) - case "last_handshake_time_nsec": - t, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond) - default: - return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key} + var ip netip.Addr + if a.AddressFamily == windows.AF_INET { + ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4])) + } else if a.AddressFamily == windows.AF_INET6 { + ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16])) } + peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr))) } + conf.Peers = append(conf.Peers, peer) } - conf.maybeAddPeer(peer) - - return &conf, nil + return &conf } |