From 019ce9f2815cd21756be4f11702fcb02b5453fdc Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 25 Feb 2019 18:45:32 +0100 Subject: conf: introduce configuration management Signed-off-by: Jason A. Donenfeld --- conf/config.go | 180 ++++++++++++++++ conf/dpapi/dpapi_windows.go | 107 +++++++++ conf/dpapi/dpapi_windows_test.go | 79 +++++++ conf/dpapi/mksyscall.go | 8 + conf/dpapi/zdpapi_windows.go | 68 ++++++ conf/mksyscall.go | 8 + conf/parser.go | 454 +++++++++++++++++++++++++++++++++++++++ conf/parser_test.go | 128 +++++++++++ conf/path_windows.go | 50 +++++ conf/store.go | 199 +++++++++++++++++ conf/store_test.go | 91 ++++++++ conf/storewatcher.go | 38 ++++ conf/storewatcher_windows.go | 59 +++++ conf/writer.go | 125 +++++++++++ conf/zsyscall_windows.go | 96 +++++++++ 15 files changed, 1690 insertions(+) create mode 100644 conf/config.go create mode 100644 conf/dpapi/dpapi_windows.go create mode 100644 conf/dpapi/dpapi_windows_test.go create mode 100644 conf/dpapi/mksyscall.go create mode 100644 conf/dpapi/zdpapi_windows.go create mode 100644 conf/mksyscall.go create mode 100644 conf/parser.go create mode 100644 conf/parser_test.go create mode 100644 conf/path_windows.go create mode 100644 conf/store.go create mode 100644 conf/store_test.go create mode 100644 conf/storewatcher.go create mode 100644 conf/storewatcher_windows.go create mode 100644 conf/writer.go create mode 100644 conf/zsyscall_windows.go (limited to 'conf') diff --git a/conf/config.go b/conf/config.go new file mode 100644 index 00000000..a321bc0c --- /dev/null +++ b/conf/config.go @@ -0,0 +1,180 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "strings" + "time" + + "golang.org/x/crypto/curve25519" +) + +const KeyLength = 32 + +type IPCidr struct { + IP net.IP + Cidr uint8 +} + +type Endpoint struct { + Host string + Port uint16 +} + +type Key [KeyLength]byte +type HandshakeTime time.Duration +type Bytes uint64 + +type Config struct { + Name string + Interface Interface + Peers []Peer +} + +type Interface struct { + PrivateKey Key + Addresses []IPCidr + ListenPort uint16 + Mtu uint16 + Dns []net.IP +} + +type Peer struct { + PublicKey Key + PresharedKey Key + AllowedIPs []IPCidr + Endpoint Endpoint + PersistentKeepalive uint16 + + RxBytes Bytes + TxBytes Bytes + LastHandshakeTime HandshakeTime +} + +func (r *IPCidr) String() string { + return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr) +} + +func (e *Endpoint) String() string { + if strings.IndexByte(e.Host, ':') > 0 { + return fmt.Sprintf("[%s]:%d", e.Host, e.Port) + } + return fmt.Sprintf("%s:%d", e.Host, e.Port) +} + +func (e *Endpoint) IsEmpty() bool { + return len(e.Host) == 0 +} + +func (k *Key) String() string { + return base64.StdEncoding.EncodeToString(k[:]) +} + +func (k *Key) HexString() string { + return hex.EncodeToString(k[:]) +} + +func (k *Key) IsZero() bool { + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k *Key) Public() *Key { + var p [KeyLength]byte + curve25519.ScalarBaseMult(&p, (*[KeyLength]byte)(k)) + return (*Key)(&p) +} + +func NewPresharedKey() (*Key, error) { + var k [KeyLength]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return (*Key)(&k), nil +} + +func NewPrivateKey() (*Key, error) { + k, err := NewPresharedKey() + if err != nil { + return nil, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return k, nil +} + +func formatInterval(i int64, n string, l int) string { + r := "" + if l > 0 { + r += ", " + } + r += fmt.Sprintf("%d %s", i, n) + if i != 1 { + r += "s" + } + return r +} + +func (t HandshakeTime) IsEmpty() bool { + return t == HandshakeTime(0) +} + +func (t HandshakeTime) String() string { + u := time.Unix(0, 0).Add(time.Duration(t)).Unix() + n := time.Now().Unix() + if u == n { + return "Now" + } else if u > n { + return "System clock wound backward!" + } + left := n - u + years := left / (365 * 24 * 60 * 60) + left = left % (365 * 24 * 60 * 60) + days := left / (24 * 60 * 60) + left = left % (24 * 60 * 60) + hours := left / (60 * 60) + left = left % (60 * 60) + minutes := left / 60 + seconds := left % 60 + s := "" + if years > 0 { + s += formatInterval(years, "year", len(s)) + } + if days > 0 { + s += formatInterval(days, "day", len(s)) + } + if hours > 0 { + s += formatInterval(hours, "hour", len(s)) + } + if minutes > 0 { + s += formatInterval(minutes, "minute", len(s)) + } + if seconds > 0 { + s += formatInterval(seconds, "second", len(s)) + } + s += " ago" + return s +} + +func (b Bytes) String() string { + if b < 1024 { + return fmt.Sprintf("%d B", b) + } else if b < 1024*1024 { + return fmt.Sprintf("%.2f KiB", float64(b)/1024) + } else if b < 1024*1024*1024 { + return fmt.Sprintf("%.2f MiB", float64(b)/(1024*1024)) + } else if b < 1024*1024*1024*1024 { + return fmt.Sprintf("%.2f GiB", float64(b)/(1024*1024*1024)) + } + return fmt.Sprintf("%.2f TiB", float64(b)/(1024*1024*1024)/1024) +} diff --git a/conf/dpapi/dpapi_windows.go b/conf/dpapi/dpapi_windows.go new file mode 100644 index 00000000..03a5d8a3 --- /dev/null +++ b/conf/dpapi/dpapi_windows.go @@ -0,0 +1,107 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package dpapi + +import ( + "errors" + "golang.org/x/sys/windows" + "runtime" + "unsafe" +) + +const ( + dpCRYPTPROTECT_UI_FORBIDDEN uint32 = 0x1 + dpCRYPTPROTECT_LOCAL_MACHINE uint32 = 0x4 + dpCRYPTPROTECT_CRED_SYNC uint32 = 0x8 + dpCRYPTPROTECT_AUDIT uint32 = 0x10 + dpCRYPTPROTECT_NO_RECOVERY uint32 = 0x20 + dpCRYPTPROTECT_VERIFY_PROTECTION uint32 = 0x40 + dpCRYPTPROTECT_CRED_REGENERATE uint32 = 0x80 +) + +type dpBlob struct { + len uint32 + data uintptr +} + +func bytesToBlob(bytes []byte) *dpBlob { + blob := &dpBlob{} + blob.len = uint32(len(bytes)) + if len(bytes) > 0 { + blob.data = uintptr(unsafe.Pointer(&bytes[0])) + } + return blob +} + +//sys cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptProtectData + +func Encrypt(data []byte, name string) ([]byte, error) { + out := dpBlob{} + err := cryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out) + if err != nil { + return nil, errors.New("Unable to encrypt DPAPI protected data: " + err.Error()) + } + + outSlice := *(*[]byte)(unsafe.Pointer(&(struct { + addr uintptr + len int + cap int + }{out.data, int(out.len), int(out.len)}))) + ret := make([]byte, len(outSlice)) + copy(ret, outSlice) + windows.LocalFree(windows.Handle(out.data)) + + return ret, nil +} + +//sys cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptUnprotectData + +func Decrypt(data []byte, name string) ([]byte, error) { + out := dpBlob{} + var outName *uint16 + utf16Name, err := windows.UTF16PtrFromString(name) + if err != nil { + return nil, err + } + + err = cryptUnprotectData(bytesToBlob(data), &outName, nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out) + if err != nil { + return nil, errors.New("Unable to decrypt DPAPI protected data: " + err.Error()) + } + + outSlice := *(*[]byte)(unsafe.Pointer(&(struct { + addr uintptr + len int + cap int + }{out.data, int(out.len), int(out.len)}))) + ret := make([]byte, len(outSlice)) + copy(ret, outSlice) + windows.LocalFree(windows.Handle(out.data)) + + // Note: this ridiculous open-coded strcmp is not constant time. + different := false + a := outName + b := utf16Name + for { + if *a != *b { + different = true + break + } + if *a == 0 || *b == 0 { + break + } + a = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(a)) + 2)) + b = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(b)) + 2)) + } + runtime.KeepAlive(utf16Name) + windows.LocalFree(windows.Handle(unsafe.Pointer(outName))) + + if different { + return nil, errors.New("The input name does not match the stored name") + } + + return ret, nil +} diff --git a/conf/dpapi/dpapi_windows_test.go b/conf/dpapi/dpapi_windows_test.go new file mode 100644 index 00000000..e0e9b42d --- /dev/null +++ b/conf/dpapi/dpapi_windows_test.go @@ -0,0 +1,79 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package dpapi + +import ( + "bytes" + "golang.org/x/sys/windows" + "testing" + "unsafe" +) + +func TestRoundTrip(t *testing.T) { + name := "golang test" + original := []byte("The quick brown fox jumped over the lazy dog") + + e, err := Encrypt(original, name) + if err != nil { + t.Errorf("Error encrypting: %s", err.Error()) + } + + if len(e) < len(original) { + t.Error("Encrypted data is smaller than original data.") + } + + d, err := Decrypt(e, name) + if err != nil { + t.Errorf("Error decrypting: %s", err.Error()) + } + + if !bytes.Equal(d, original) { + t.Error("Decrypted content does not match original") + } + + _, err = Decrypt(e, "bad name") + if err == nil { + t.Error("Decryption failed to notice ad mismatch") + } + + eCorrupt := make([]byte, len(e)) + copy(eCorrupt, e) + eCorrupt[len(original)-1] = 7 + _, err = Decrypt(eCorrupt, name) + if err == nil { + t.Error("Decryption failed to notice ciphertext corruption") + } + + copy(eCorrupt, e) + nameUtf16, err := windows.UTF16FromString(name) + if err != nil { + t.Errorf("Unable to get utf16 chars for name: %s", err) + } + nameUtf16Bytes := *(*[]byte)(unsafe.Pointer(&struct { + addr *byte + len int + cap int + }{(*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16) * 2, cap(nameUtf16) * 2})) + i := bytes.Index(eCorrupt, nameUtf16Bytes) + if i == -1 { + t.Error("Unable to find ad in blob") + } else { + eCorrupt[i] = 7 + _, err = Decrypt(eCorrupt, name) + if err == nil { + t.Error("Decryption failed to notice ad corruption") + } + } + + // BUG: Actually, Windows doesn't report length extension of the buffer, unfortunately. + // + // eCorrupt = make([]byte, len(e)+1) + // copy(eCorrupt, e) + // _, err = Decrypt(eCorrupt, name) + // if err == nil { + // t.Error("Decryption failed to notice length extension") + // } +} diff --git a/conf/dpapi/mksyscall.go b/conf/dpapi/mksyscall.go new file mode 100644 index 00000000..f80c3fd2 --- /dev/null +++ b/conf/dpapi/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package dpapi + +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zdpapi_windows.go dpapi_windows.go diff --git a/conf/dpapi/zdpapi_windows.go b/conf/dpapi/zdpapi_windows.go new file mode 100644 index 00000000..e48d36b2 --- /dev/null +++ b/conf/dpapi/zdpapi_windows.go @@ -0,0 +1,68 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package dpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modcrypt32 = windows.NewLazySystemDLL("crypt32.dll") + + procCryptProtectData = modcrypt32.NewProc("CryptProtectData") + procCryptUnprotectData = modcrypt32.NewProc("CryptUnprotectData") +) + +func cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) { + r1, _, e1 := syscall.Syscall9(procCryptProtectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) { + r1, _, e1 := syscall.Syscall9(procCryptUnprotectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} diff --git a/conf/mksyscall.go b/conf/mksyscall.go new file mode 100644 index 00000000..2bdb0204 --- /dev/null +++ b/conf/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go path_windows.go storewatcher_windows.go diff --git a/conf/parser.go b/conf/parser.go new file mode 100644 index 00000000..6a397a9d --- /dev/null +++ b/conf/parser.go @@ -0,0 +1,454 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" + "time" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: ā€˜%sā€™", e.why, e.offender) +} + +func parseIPCidr(s string) (ipcidr *IPCidr, err error) { + var addrStr, cidrStr string + var cidr int + + i := strings.IndexByte(s, '/') + if i < 0 { + addrStr = s + } else { + addrStr, cidrStr = s[:i], s[i+1:] + } + + err = &ParseError{"Invalid IP address", s} + addr := net.ParseIP(addrStr) + if addr == nil { + return + } + if len(cidrStr) > 0 { + err = &ParseError{"Invalid network prefix length", s} + cidr, err = strconv.Atoi(cidrStr) + if err != nil || cidr < 0 || cidr > 128 { + return + } + if cidr > 32 && addr.To4() != nil { + return + } + } else { + if addr.To4() != nil { + cidr = 32 + } else { + cidr = 128 + } + } + return &IPCidr{addr, uint8(cidr)}, nil +} + +func parseEndpoint(s string) (*Endpoint, error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return nil, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return nil, &ParseError{"Invalid endpoint host", host} + } + port, err := parsePort(portStr) + if err != nil { + return nil, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return nil, err + } + } else { + return nil, err + } + host = host[1 : len(host)-1] + } + return &Endpoint{host, uint16(port)}, nil +} + +func parseMTU(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 576 || m > 65535 { + return 0, &ParseError{"Invalid MTU", s} + } + return uint16(m), nil +} + +func parsePort(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid port", s} + } + return uint16(m), nil +} + +func parsePersistentKeepalive(s string) (uint16, error) { + if s == "off" { + return 0, nil + } + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid persistent keepalive", s} + } + return uint16(m), nil +} + +func parseKeyBase64(s string) (*Key, error) { + k, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeyLength { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseKeyHex(s string) (*Key, error) { + k, err := hex.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeyLength { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseBytesOrStamp(s string) (uint64, error) { + b, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s} + } + return b, nil +} + +func splitList(s string) ([]string, error) { + var out []string + for _, split := range strings.Split(s, ",") { + trim := strings.TrimSpace(split) + if len(trim) == 0 { + return nil, &ParseError{"Two commas in a row", s} + } + out = append(out, trim) + } + return out, nil +} + +type parserState int + +const ( + inInterfaceSection parserState = iota + inPeerSection + notInASection +) + +func (c *Config) maybeAddPeer(p *Peer) { + if p != nil { + c.Peers = append(c.Peers, *p) + } +} + +func FromWgQuick(s string, name string) (*Config, error) { + lines := strings.Split(s, "\n") + parserState := notInASection + conf := Config{Name: name} + sawPrivateKey := false + var peer *Peer + for _, line := range lines { + pound := strings.IndexByte(line, '#') + if pound >= 0 { + line = line[:pound] + } + line = strings.TrimSpace(line) + lineLower := strings.ToLower(line) + if len(line) == 0 { + continue + } + if lineLower == "[interface]" { + conf.maybeAddPeer(peer) + parserState = inInterfaceSection + continue + } + if lineLower == "[peer]" { + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + continue + } + if parserState == notInASection { + return nil, &ParseError{"Line must occur in a section", line} + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + if parserState == inInterfaceSection { + switch key { + case "privatekey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + conf.Interface.PrivateKey = *k + sawPrivateKey = true + case "listenport": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.Interface.ListenPort = p + case "mtu": + m, err := parseMTU(val) + if err != nil { + return nil, err + } + conf.Interface.Mtu = m + case "address": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := parseIPCidr(address) + if err != nil { + return nil, err + } + conf.Interface.Addresses = append(conf.Interface.Addresses, *a) + } + case "dns": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a := net.ParseIP(address) + if a == nil { + return nil, &ParseError{"Invalid IP address", address} + } + conf.Interface.Dns = append(conf.Interface.Dns, a) + } + default: + return nil, &ParseError{"Invalid key for [Interface] section", key} + } + } else if parserState == inPeerSection { + switch key { + case "publickey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "presharedkey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PresharedKey = *k + case "allowedips": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := parseIPCidr(address) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + } + case "persistentkeepalive": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + peer.Endpoint = *e + default: + return nil, &ParseError{"Invalid key for [Peer] section", key} + } + } + } + conf.maybeAddPeer(peer) + + if !sawPrivateKey { + return nil, &ParseError{"An interface must have a private key", "[none specified]"} + } + for _, p := range conf.Peers { + if p.PublicKey.IsZero() { + return nil, &ParseError{"All peers must have public keys", "[none specified]"} + } + } + + return &conf, nil +} + +func FromUAPI(s string, existingConfig *Config) (*Config, error) { + lines := strings.Split(s, "\n") + parserState := inInterfaceSection + conf := Config{ + Name: existingConfig.Name, + Interface: Interface{ + Addresses: existingConfig.Interface.Addresses, + Dns: existingConfig.Interface.Dns, + Mtu: existingConfig.Interface.Mtu, + }, + } + var peer *Peer + for _, line := range lines { + if len(line) == 0 { + continue + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := line[:equals], line[equals+1:] + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + switch key { + case "public_key": + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + case "errno": + if val == "0" { + continue + } else { + return nil, &ParseError{"Error in getting configuration", val} + } + } + if parserState == inInterfaceSection { + switch key { + case "private_key": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + conf.Interface.PrivateKey = *k + case "listen_port": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.Interface.ListenPort = p + case "fwmark": + // Ignored for now. + + default: + return nil, &ParseError{"Invalid key for interface section", key} + } + } else if parserState == inPeerSection { + switch key { + case "public_key": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "preshared_key": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PresharedKey = *k + case "protocol_version": + if val != "1" { + return nil, &ParseError{"Protocol version must be 1", val} + } + case "allowed_ip": + a, err := parseIPCidr(val) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + case "persistent_keepalive_interval": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + peer.Endpoint = *e + case "tx_bytes": + b, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.TxBytes = Bytes(b) + case "rx_bytes": + b, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.RxBytes = Bytes(b) + case "last_handshake_time_sec": + t, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second) + case "last_handshake_time_nsec": + t, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond) + default: + return nil, &ParseError{"Invalid key for peer section", key} + } + } + } + conf.maybeAddPeer(peer) + + return &conf, nil +} diff --git a/conf/parser_test.go b/conf/parser_test.go new file mode 100644 index 00000000..a6afbf53 --- /dev/null +++ b/conf/parser_test.go @@ -0,0 +1,128 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "net" + "reflect" + "runtime" + "testing" +) + +const testInput = ` +[Interface] +Address = 10.192.122.1/24 +Address = 10.10.0.1/16 +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +ListenPort = 51820 #comments don't matter + +[Peer] +PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +Endpoint = 192.95.5.67:1234 +AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = [2607:5300:60:6b0::c05f:543]:2468 +AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 +PersistentKeepalive = 100 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = test.wireguard.com:18981 +AllowedIPs = 10.10.10.230/32` + +func noError(t *testing.T, err error) bool { + if err == nil { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error at %s:%d: %#v", fn, line, err) + return false +} + +func equal(t *testing.T, expected, actual interface{}) bool { + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func lenTest(t *testing.T, actualO interface{}, expected int) bool { + actual := reflect.ValueOf(actualO).Len() + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func contains(t *testing.T, list, element interface{}) bool { + listValue := reflect.ValueOf(list) + for i := 0; i < listValue.Len(); i++ { + if reflect.DeepEqual(listValue.Index(i).Interface(), element) { + return true + } + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) + return false +} + +func TestFromWgQuick(t *testing.T) { + conf, err := FromWgQuick(testInput, "test") + if noError(t, err) { + + lenTest(t, conf.Interface.Addresses, 2) + contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)}) + contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)}) + equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String()) + equal(t, uint16(51820), conf.Interface.ListenPort) + + lenTest(t, conf.Peers, 3) + lenTest(t, conf.Peers[0].AllowedIPs, 2) + equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoint) + equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.String()) + + lenTest(t, conf.Peers[1].AllowedIPs, 2) + equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoint) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.String()) + equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) + + lenTest(t, conf.Peers[2].AllowedIPs, 1) + equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoint) + equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.String()) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.String()) + } +} + +func TestParseEndpoint(t *testing.T) { + _, err := parseEndpoint("[192.168.42.0:]:51880") + if err == nil { + t.Error("Error was expected") + } + e, err := parseEndpoint("192.168.42.0:51880") + if noError(t, err) { + equal(t, "192.168.42.0", e.Host) + equal(t, uint16(51880), e.Port) + } + e, err = parseEndpoint("test.wireguard.com:18981") + if noError(t, err) { + equal(t, "test.wireguard.com", e.Host) + equal(t, uint16(18981), e.Port) + } + e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") + if noError(t, err) { + equal(t, "2607:5300:60:6b0::c05f:543", e.Host) + equal(t, uint16(2468), e.Port) + } + _, err = parseEndpoint("[::::::invalid:18981") + if err == nil { + t.Error("Error was expected") + } +} diff --git a/conf/path_windows.go b/conf/path_windows.go new file mode 100644 index 00000000..0ee3fc73 --- /dev/null +++ b/conf/path_windows.go @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "golang.org/x/sys/windows" + "os" + "path/filepath" + "unsafe" +) + +//sys coTaskMemFree(pointer uintptr) = ole32.CoTaskMemFree +//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) [failretval!=0] = shell32.SHGetKnownFolderPath +var folderIDLocalAppData = windows.GUID{0xf1b32785, 0x6fba, 0x4fcf, [8]byte{0x9d, 0x55, 0x7b, 0x8e, 0x7f, 0x15, 0x70, 0x91}} + +const kfFlagCreate = 0x00008000 + +var cachedConfigFileDir string + +func resolveConfigFileDir() (string, error) { + if cachedConfigFileDir != "" { + return cachedConfigFileDir, nil + } + processToken, err := windows.OpenCurrentProcessToken() + if err != nil { + return "", err + } + defer processToken.Close() + var path *uint16 + err = shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, windows.Handle(processToken), &path) + if err != nil { + return "", err + } + defer coTaskMemFree(uintptr(unsafe.Pointer(path))) + root := windows.UTF16ToString((*[windows.MAX_LONG_PATH + 1]uint16)(unsafe.Pointer(path))[:]) + if len(root) == 0 { + return "", errors.New("Unable to determine configuration directory") + } + c := filepath.Join(root, "WireGuard", "Configurations") + err = os.MkdirAll(c, os.ModeDir|0700) + if err != nil { + return "", err + } + cachedConfigFileDir = c + return cachedConfigFileDir, nil +} diff --git a/conf/store.go b/conf/store.go new file mode 100644 index 00000000..7c110865 --- /dev/null +++ b/conf/store.go @@ -0,0 +1,199 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "golang.zx2c4.com/wireguard/windows/conf/dpapi" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +const configFileSuffix = ".conf.dpapi" +const configFileUnencryptedSuffix = ".conf" + +func ListConfigNames() ([]string, error) { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return nil, err + } + files, err := ioutil.ReadDir(configFileDir) + if err != nil { + return nil, err + } + configs := make([]string, len(files)) + i := 0 + for _, file := range files { + name := filepath.Base(file.Name()) + if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) { + continue + } + if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 { + continue + } + configs[i] = strings.TrimSuffix(name, configFileSuffix) + i++ + } + return configs[:i], nil +} + +func MigrateUnencryptedConfigs() (int, []error) { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return 0, []error{err} + } + files, err := ioutil.ReadDir(configFileDir) + if err != nil { + return 0, []error{err} + } + errs := make([]error, len(files)) + i := 0 + e := 0 + for _, file := range files { + path := filepath.Join(configFileDir, file.Name()) + name := filepath.Base(file.Name()) + if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) { + continue + } + if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 { + continue + } + + // We don't use ioutil's ReadFile, because we actually want RDWR, so that we can take advantage + // of Windows file locking for ensuring the file is finished being written. + f, err := os.OpenFile(path, os.O_RDWR, 0) + if err != nil { + errs[e] = err + e++ + continue + } + bytes, err := ioutil.ReadAll(f) + f.Close() + if err != nil { + errs[e] = err + e++ + continue + } + _, err = FromWgQuick(string(bytes), "input") + if err != nil { + errs[e] = err + e++ + continue + } + + bytes, err = dpapi.Encrypt(bytes, strings.TrimSuffix(name, configFileUnencryptedSuffix)) + if err != nil { + errs[e] = err + e++ + continue + } + dstFile := strings.TrimSuffix(path, configFileUnencryptedSuffix) + configFileSuffix + if _, err = os.Stat(dstFile); err != nil && !os.IsNotExist(err) { + errs[e] = errors.New("Unable to migrate to " + dstFile + " as it already exists") + e++ + continue + } + err = ioutil.WriteFile(dstFile, bytes, 0600) + if err != nil { + errs[e] = err + e++ + continue + } + err = os.Remove(path) + if err != nil && os.Remove(dstFile) == nil { + errs[e] = err + e++ + continue + } + i++ + } + return i, errs[:e] +} + +func LoadFromName(name string) (*Config, error) { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return nil, err + } + return LoadFromPath(filepath.Join(configFileDir, name+configFileSuffix)) +} + +func LoadFromPath(path string) (*Config, error) { + name, err := NameFromPath(path) + if err != nil { + return nil, err + } + bytes, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + if strings.HasSuffix(path, configFileSuffix) { + bytes, err = dpapi.Decrypt(bytes, name) + if err != nil { + return nil, err + } + } + return FromWgQuick(string(bytes), name) +} + +func NameFromPath(path string) (string, error) { + name := filepath.Base(path) + if !((len(name) > len(configFileSuffix) && strings.HasSuffix(name, configFileSuffix)) || + (len(name) > len(configFileUnencryptedSuffix) && strings.HasSuffix(name, configFileUnencryptedSuffix))) { + return "", errors.New("Path must end in either " + configFileSuffix + " or " + configFileUnencryptedSuffix) + } + if strings.HasSuffix(path, configFileSuffix) { + name = strings.TrimSuffix(name, configFileSuffix) + } else { + name = strings.TrimSuffix(name, configFileUnencryptedSuffix) + } + return name, nil +} + +func (config *Config) Save() error { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return err + } + filename := filepath.Join(configFileDir, config.Name+configFileSuffix) + bytes := []byte(config.ToWgQuick()) + bytes, err = dpapi.Encrypt(bytes, config.Name) + if err != nil { + return err + } + err = ioutil.WriteFile(filename+".tmp", bytes, 0600) + if err != nil { + return err + } + err = os.Rename(filename+".tmp", filename) + if err != nil { + os.Remove(filename + ".tmp") + return err + } + return nil +} + +func (config *Config) Path() (string, error) { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return "", err + } + return filepath.Join(configFileDir, config.Name+configFileSuffix), nil +} + +func DeleteName(name string) error { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return err + } + return os.Remove(filepath.Join(configFileDir, name+configFileSuffix)) +} + +func (config *Config) Delete() error { + return DeleteName(config.Name) +} diff --git a/conf/store_test.go b/conf/store_test.go new file mode 100644 index 00000000..fdef7ea7 --- /dev/null +++ b/conf/store_test.go @@ -0,0 +1,91 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "reflect" + "testing" +) + +func TestStorage(t *testing.T) { + c, err := FromWgQuick(testInput, "golangTest") + if err != nil { + t.Errorf("Unable to parse test config: %s", err.Error()) + return + } + + err = c.Save() + if err != nil { + t.Errorf("Unable to save config: %s", err.Error()) + } + + configs, err := ListConfigNames() + if err != nil { + t.Errorf("Unable to list configs: %s", err.Error()) + } + + found := false + for _, name := range configs { + if name == "golangTest" { + found = true + break + } + } + if !found { + t.Error("Unable to find saved config in list") + } + + loaded, err := LoadFromName("golangTest") + if err != nil { + t.Errorf("Unable to load config: %s", err.Error()) + return + } + + if !reflect.DeepEqual(loaded, c) { + t.Error("Loaded config is not the same as saved config") + } + + k, err := NewPrivateKey() + if err != nil { + t.Errorf("Unable to generate new private key: %s", err.Error()) + } + c.Interface.PrivateKey = *k + + err = c.Save() + if err != nil { + t.Errorf("Unable to save config a second time: %s", err.Error()) + } + + loaded, err = LoadFromName("golangTest") + if err != nil { + t.Errorf("Unable to load config a second time: %s", err.Error()) + return + } + + if !reflect.DeepEqual(loaded, c) { + t.Error("Second loaded config is not the same as second saved config") + } + + err = DeleteName("golangTest") + if err != nil { + t.Errorf("Unable to delete config: %s", err.Error()) + } + + configs, err = ListConfigNames() + if err != nil { + t.Errorf("Unable to list configs: %s", err.Error()) + } + found = false + for _, name := range configs { + if name == "golangTest" { + found = true + break + } + } + if found { + t.Error("Config wasn't actually deleted") + } +} diff --git a/conf/storewatcher.go b/conf/storewatcher.go new file mode 100644 index 00000000..4f9c2ef7 --- /dev/null +++ b/conf/storewatcher.go @@ -0,0 +1,38 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import "reflect" + +type StoreCallback func() + +var storeCallbacks []StoreCallback + +func RegisterStoreChangeCallback(cb StoreCallback) { + startWatchingConfigDir() + cb() + storeCallbacks = append(storeCallbacks, cb) +} + +func UnregisterStoreChangeCallback(cb StoreCallback) { + //TODO: this function is ridiculous, doing slow iteration like this and reflection too. + + index := -1 + for i, e := range storeCallbacks { + if reflect.ValueOf(e).Pointer() == reflect.ValueOf(cb).Pointer() { + index = i + break + } + } + if index == -1 { + return + } + newList := storeCallbacks[0:index] + if index < len(storeCallbacks)-1 { + newList = append(newList, storeCallbacks[index+1:]...) + } + storeCallbacks = newList +} diff --git a/conf/storewatcher_windows.go b/conf/storewatcher_windows.go new file mode 100644 index 00000000..f3f38fef --- /dev/null +++ b/conf/storewatcher_windows.go @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "golang.org/x/sys/windows" + "log" +) + +const ( + fncFILE_NAME uint32 = 0x00000001 + fncDIR_NAME uint32 = 0x00000002 + fncATTRIBUTES uint32 = 0x00000004 + fncSIZE uint32 = 0x00000008 + fncLAST_WRITE uint32 = 0x00000010 + fncLAST_ACCESS uint32 = 0x00000020 + fncCREATION uint32 = 0x00000040 + fncSECURITY uint32 = 0x00000100 +) + +//sys findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) = kernel32.FindFirstChangeNotificationW +//sys findNextChangeNotification(handle windows.Handle) (err error) = kernel32.FindNextChangeNotification + +var haveStartedWatchingConfigDir bool + +func startWatchingConfigDir() { + if haveStartedWatchingConfigDir { + return + } + haveStartedWatchingConfigDir = true + go func() { + configFileDir, err := resolveConfigFileDir() + if err != nil { + return + } + h, err := findFirstChangeNotification(windows.StringToUTF16Ptr(configFileDir), true, fncFILE_NAME|fncDIR_NAME|fncATTRIBUTES|fncSIZE|fncLAST_WRITE|fncLAST_ACCESS|fncCREATION|fncSECURITY) + if err != nil { + log.Fatalf("Unable to monitor config directory: %v", err) + } + for { + s, err := windows.WaitForSingleObject(h, windows.INFINITE) + if err != nil || s == windows.WAIT_FAILED { + log.Fatalf("Unable to wait on config directory watcher: %v", err) + } + + for _, cb := range storeCallbacks { + cb() + } + + err = findNextChangeNotification(h) + if err != nil { + log.Fatalf("Unable to monitor config directory again: %v", err) + } + } + }() +} diff --git a/conf/writer.go b/conf/writer.go new file mode 100644 index 00000000..642d14a7 --- /dev/null +++ b/conf/writer.go @@ -0,0 +1,125 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "fmt" + "net" + "strings" +) + +func (conf *Config) ToWgQuick() string { + var output strings.Builder + output.WriteString("[Interface]\n") + + output.WriteString(fmt.Sprintf("PrivateKey = %s\n", conf.Interface.PrivateKey.String())) + + if conf.Interface.ListenPort > 0 { + output.WriteString(fmt.Sprintf("ListenPort = %d\n", conf.Interface.ListenPort)) + } + + if len(conf.Interface.Addresses) > 0 { + addrStrings := make([]string, len(conf.Interface.Addresses)) + for i, address := range conf.Interface.Addresses { + addrStrings[i] = address.String() + } + output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if len(conf.Interface.Dns) > 0 { + addrStrings := make([]string, len(conf.Interface.Dns)) + for i, address := range conf.Interface.Dns { + addrStrings[i] = address.String() + } + output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if conf.Interface.Mtu > 0 { + output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.Mtu)) + } + + for _, peer := range conf.Peers { + output.WriteString("\n[Peer]\n") + + output.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey.String())) + + if !peer.PresharedKey.IsZero() { + output.WriteString(fmt.Sprintf("PresharedKey = %s\n", peer.PresharedKey.String())) + } + + if len(peer.AllowedIPs) > 0 { + addrStrings := make([]string, len(peer.AllowedIPs)) + for i, address := range peer.AllowedIPs { + addrStrings[i] = address.String() + } + output.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if !peer.Endpoint.IsEmpty() { + output.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint.String())) + } + + if peer.PersistentKeepalive > 0 { + output.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", peer.PersistentKeepalive)) + } + } + return output.String() +} + +func (conf *Config) ToUAPI() (string, error) { + var output strings.Builder + output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString())) + + if conf.Interface.ListenPort > 0 { + output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort)) + } + + if len(conf.Peers) > 0 { + output.WriteString("replace_peers=true\n") + } + + for _, peer := range conf.Peers { + output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString())) + + if !peer.PresharedKey.IsZero() { + output.WriteString(fmt.Sprintf("preshared_key = %s\n", peer.PresharedKey.String())) + } + + if !peer.Endpoint.IsEmpty() { + ips, err := net.LookupIP(peer.Endpoint.Host) + if err != nil { + return "", err + } + var ip net.IP + for _, iterip := range ips { + iterip = iterip.To4() + if iterip != nil { + ip = iterip + break + } + if ip == nil { + ip = iterip + } + } + if ip == nil { + return "", errors.New("Unable to resolve IP address of endpoint") + } + resolvedEndpoint := Endpoint{ip.String(), peer.Endpoint.Port} + output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String())) + } + + output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive)) + + if len(peer.AllowedIPs) > 0 { + output.WriteString("replace_allowed_ips=true\n") + for _, address := range peer.AllowedIPs { + output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String())) + } + } + } + return output.String(), nil +} diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go new file mode 100644 index 00000000..64bec1f4 --- /dev/null +++ b/conf/zsyscall_windows.go @@ -0,0 +1,96 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package conf + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modole32 = windows.NewLazySystemDLL("ole32.dll") + modshell32 = windows.NewLazySystemDLL("shell32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procCoTaskMemFree = modole32.NewProc("CoTaskMemFree") + procSHGetKnownFolderPath = modshell32.NewProc("SHGetKnownFolderPath") + procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW") + procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification") +) + +func coTaskMemFree(pointer uintptr) { + syscall.Syscall(procCoTaskMemFree.Addr(), 1, uintptr(pointer), 0, 0) + return +} + +func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) { + r1, _, e1 := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0) + if r1 != 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) { + var _p0 uint32 + if watchSubtree { + _p0 = 1 + } else { + _p0 = 0 + } + r0, _, e1 := syscall.Syscall(procFindFirstChangeNotificationW.Addr(), 3, uintptr(unsafe.Pointer(path)), uintptr(_p0), uintptr(filter)) + handle = windows.Handle(r0) + if handle == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func findNextChangeNotification(handle windows.Handle) (err error) { + r1, _, e1 := syscall.Syscall(procFindNextChangeNotification.Addr(), 1, uintptr(handle), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} -- cgit v1.2.3-59-g8ed1b