aboutsummaryrefslogtreecommitdiffstats
path: root/conf
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-02-25 18:45:32 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2019-02-28 08:05:02 +0100
commit840f33de326233d5fee1144334db41bf5c82a8fa (patch)
tree43070181e30db403dfad69f3e67a566ba589df4e /conf
parentInitial scaffolding (diff)
downloadwireguard-windows-840f33de326233d5fee1144334db41bf5c82a8fa.tar.xz
wireguard-windows-840f33de326233d5fee1144334db41bf5c82a8fa.zip
conf: introduce configuration management
Diffstat (limited to 'conf')
-rw-r--r--conf/config.go180
-rw-r--r--conf/dpapi/dpapi_windows.go107
-rw-r--r--conf/dpapi/dpapi_windows_test.go79
-rw-r--r--conf/dpapi/mksyscall.go8
-rw-r--r--conf/dpapi/zdpapi_windows.go68
-rw-r--r--conf/mksyscall.go8
-rw-r--r--conf/parser.go454
-rw-r--r--conf/parser_test.go128
-rw-r--r--conf/path_windows.go50
-rw-r--r--conf/store.go199
-rw-r--r--conf/store_test.go91
-rw-r--r--conf/storewatcher.go38
-rw-r--r--conf/storewatcher_windows.go59
-rw-r--r--conf/writer.go125
-rw-r--r--conf/zsyscall_windows.go96
15 files changed, 1690 insertions, 0 deletions
diff --git a/conf/config.go b/conf/config.go
new file mode 100644
index 00000000..a321bc0c
--- /dev/null
+++ b/conf/config.go
@@ -0,0 +1,180 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/curve25519"
+)
+
+const KeyLength = 32
+
+type IPCidr struct {
+ IP net.IP
+ Cidr uint8
+}
+
+type Endpoint struct {
+ Host string
+ Port uint16
+}
+
+type Key [KeyLength]byte
+type HandshakeTime time.Duration
+type Bytes uint64
+
+type Config struct {
+ Name string
+ Interface Interface
+ Peers []Peer
+}
+
+type Interface struct {
+ PrivateKey Key
+ Addresses []IPCidr
+ ListenPort uint16
+ Mtu uint16
+ Dns []net.IP
+}
+
+type Peer struct {
+ PublicKey Key
+ PresharedKey Key
+ AllowedIPs []IPCidr
+ Endpoint Endpoint
+ PersistentKeepalive uint16
+
+ RxBytes Bytes
+ TxBytes Bytes
+ LastHandshakeTime HandshakeTime
+}
+
+func (r *IPCidr) String() string {
+ return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
+}
+
+func (e *Endpoint) String() string {
+ if strings.IndexByte(e.Host, ':') > 0 {
+ return fmt.Sprintf("[%s]:%d", e.Host, e.Port)
+ }
+ return fmt.Sprintf("%s:%d", e.Host, e.Port)
+}
+
+func (e *Endpoint) IsEmpty() bool {
+ return len(e.Host) == 0
+}
+
+func (k *Key) String() string {
+ return base64.StdEncoding.EncodeToString(k[:])
+}
+
+func (k *Key) HexString() string {
+ return hex.EncodeToString(k[:])
+}
+
+func (k *Key) IsZero() bool {
+ var zeros Key
+ return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
+}
+
+func (k *Key) Public() *Key {
+ var p [KeyLength]byte
+ curve25519.ScalarBaseMult(&p, (*[KeyLength]byte)(k))
+ return (*Key)(&p)
+}
+
+func NewPresharedKey() (*Key, error) {
+ var k [KeyLength]byte
+ _, err := rand.Read(k[:])
+ if err != nil {
+ return nil, err
+ }
+ return (*Key)(&k), nil
+}
+
+func NewPrivateKey() (*Key, error) {
+ k, err := NewPresharedKey()
+ if err != nil {
+ return nil, err
+ }
+ k[0] &= 248
+ k[31] = (k[31] & 127) | 64
+ return k, nil
+}
+
+func formatInterval(i int64, n string, l int) string {
+ r := ""
+ if l > 0 {
+ r += ", "
+ }
+ r += fmt.Sprintf("%d %s", i, n)
+ if i != 1 {
+ r += "s"
+ }
+ return r
+}
+
+func (t HandshakeTime) IsEmpty() bool {
+ return t == HandshakeTime(0)
+}
+
+func (t HandshakeTime) String() string {
+ u := time.Unix(0, 0).Add(time.Duration(t)).Unix()
+ n := time.Now().Unix()
+ if u == n {
+ return "Now"
+ } else if u > n {
+ return "System clock wound backward!"
+ }
+ left := n - u
+ years := left / (365 * 24 * 60 * 60)
+ left = left % (365 * 24 * 60 * 60)
+ days := left / (24 * 60 * 60)
+ left = left % (24 * 60 * 60)
+ hours := left / (60 * 60)
+ left = left % (60 * 60)
+ minutes := left / 60
+ seconds := left % 60
+ s := ""
+ if years > 0 {
+ s += formatInterval(years, "year", len(s))
+ }
+ if days > 0 {
+ s += formatInterval(days, "day", len(s))
+ }
+ if hours > 0 {
+ s += formatInterval(hours, "hour", len(s))
+ }
+ if minutes > 0 {
+ s += formatInterval(minutes, "minute", len(s))
+ }
+ if seconds > 0 {
+ s += formatInterval(seconds, "second", len(s))
+ }
+ s += " ago"
+ return s
+}
+
+func (b Bytes) String() string {
+ if b < 1024 {
+ return fmt.Sprintf("%d B", b)
+ } else if b < 1024*1024 {
+ return fmt.Sprintf("%.2f KiB", float64(b)/1024)
+ } else if b < 1024*1024*1024 {
+ return fmt.Sprintf("%.2f MiB", float64(b)/(1024*1024))
+ } else if b < 1024*1024*1024*1024 {
+ return fmt.Sprintf("%.2f GiB", float64(b)/(1024*1024*1024))
+ }
+ return fmt.Sprintf("%.2f TiB", float64(b)/(1024*1024*1024)/1024)
+}
diff --git a/conf/dpapi/dpapi_windows.go b/conf/dpapi/dpapi_windows.go
new file mode 100644
index 00000000..03a5d8a3
--- /dev/null
+++ b/conf/dpapi/dpapi_windows.go
@@ -0,0 +1,107 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package dpapi
+
+import (
+ "errors"
+ "golang.org/x/sys/windows"
+ "runtime"
+ "unsafe"
+)
+
+const (
+ dpCRYPTPROTECT_UI_FORBIDDEN uint32 = 0x1
+ dpCRYPTPROTECT_LOCAL_MACHINE uint32 = 0x4
+ dpCRYPTPROTECT_CRED_SYNC uint32 = 0x8
+ dpCRYPTPROTECT_AUDIT uint32 = 0x10
+ dpCRYPTPROTECT_NO_RECOVERY uint32 = 0x20
+ dpCRYPTPROTECT_VERIFY_PROTECTION uint32 = 0x40
+ dpCRYPTPROTECT_CRED_REGENERATE uint32 = 0x80
+)
+
+type dpBlob struct {
+ len uint32
+ data uintptr
+}
+
+func bytesToBlob(bytes []byte) *dpBlob {
+ blob := &dpBlob{}
+ blob.len = uint32(len(bytes))
+ if len(bytes) > 0 {
+ blob.data = uintptr(unsafe.Pointer(&bytes[0]))
+ }
+ return blob
+}
+
+//sys cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptProtectData
+
+func Encrypt(data []byte, name string) ([]byte, error) {
+ out := dpBlob{}
+ err := cryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out)
+ if err != nil {
+ return nil, errors.New("Unable to encrypt DPAPI protected data: " + err.Error())
+ }
+
+ outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
+ addr uintptr
+ len int
+ cap int
+ }{out.data, int(out.len), int(out.len)})))
+ ret := make([]byte, len(outSlice))
+ copy(ret, outSlice)
+ windows.LocalFree(windows.Handle(out.data))
+
+ return ret, nil
+}
+
+//sys cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptUnprotectData
+
+func Decrypt(data []byte, name string) ([]byte, error) {
+ out := dpBlob{}
+ var outName *uint16
+ utf16Name, err := windows.UTF16PtrFromString(name)
+ if err != nil {
+ return nil, err
+ }
+
+ err = cryptUnprotectData(bytesToBlob(data), &outName, nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out)
+ if err != nil {
+ return nil, errors.New("Unable to decrypt DPAPI protected data: " + err.Error())
+ }
+
+ outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
+ addr uintptr
+ len int
+ cap int
+ }{out.data, int(out.len), int(out.len)})))
+ ret := make([]byte, len(outSlice))
+ copy(ret, outSlice)
+ windows.LocalFree(windows.Handle(out.data))
+
+ // Note: this ridiculous open-coded strcmp is not constant time.
+ different := false
+ a := outName
+ b := utf16Name
+ for {
+ if *a != *b {
+ different = true
+ break
+ }
+ if *a == 0 || *b == 0 {
+ break
+ }
+ a = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(a)) + 2))
+ b = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(b)) + 2))
+ }
+ runtime.KeepAlive(utf16Name)
+ windows.LocalFree(windows.Handle(unsafe.Pointer(outName)))
+
+ if different {
+ return nil, errors.New("The input name does not match the stored name")
+ }
+
+ return ret, nil
+}
diff --git a/conf/dpapi/dpapi_windows_test.go b/conf/dpapi/dpapi_windows_test.go
new file mode 100644
index 00000000..e0e9b42d
--- /dev/null
+++ b/conf/dpapi/dpapi_windows_test.go
@@ -0,0 +1,79 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package dpapi
+
+import (
+ "bytes"
+ "golang.org/x/sys/windows"
+ "testing"
+ "unsafe"
+)
+
+func TestRoundTrip(t *testing.T) {
+ name := "golang test"
+ original := []byte("The quick brown fox jumped over the lazy dog")
+
+ e, err := Encrypt(original, name)
+ if err != nil {
+ t.Errorf("Error encrypting: %s", err.Error())
+ }
+
+ if len(e) < len(original) {
+ t.Error("Encrypted data is smaller than original data.")
+ }
+
+ d, err := Decrypt(e, name)
+ if err != nil {
+ t.Errorf("Error decrypting: %s", err.Error())
+ }
+
+ if !bytes.Equal(d, original) {
+ t.Error("Decrypted content does not match original")
+ }
+
+ _, err = Decrypt(e, "bad name")
+ if err == nil {
+ t.Error("Decryption failed to notice ad mismatch")
+ }
+
+ eCorrupt := make([]byte, len(e))
+ copy(eCorrupt, e)
+ eCorrupt[len(original)-1] = 7
+ _, err = Decrypt(eCorrupt, name)
+ if err == nil {
+ t.Error("Decryption failed to notice ciphertext corruption")
+ }
+
+ copy(eCorrupt, e)
+ nameUtf16, err := windows.UTF16FromString(name)
+ if err != nil {
+ t.Errorf("Unable to get utf16 chars for name: %s", err)
+ }
+ nameUtf16Bytes := *(*[]byte)(unsafe.Pointer(&struct {
+ addr *byte
+ len int
+ cap int
+ }{(*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16) * 2, cap(nameUtf16) * 2}))
+ i := bytes.Index(eCorrupt, nameUtf16Bytes)
+ if i == -1 {
+ t.Error("Unable to find ad in blob")
+ } else {
+ eCorrupt[i] = 7
+ _, err = Decrypt(eCorrupt, name)
+ if err == nil {
+ t.Error("Decryption failed to notice ad corruption")
+ }
+ }
+
+ // BUG: Actually, Windows doesn't report length extension of the buffer, unfortunately.
+ //
+ // eCorrupt = make([]byte, len(e)+1)
+ // copy(eCorrupt, e)
+ // _, err = Decrypt(eCorrupt, name)
+ // if err == nil {
+ // t.Error("Decryption failed to notice length extension")
+ // }
+}
diff --git a/conf/dpapi/mksyscall.go b/conf/dpapi/mksyscall.go
new file mode 100644
index 00000000..f80c3fd2
--- /dev/null
+++ b/conf/dpapi/mksyscall.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package dpapi
+
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zdpapi_windows.go dpapi_windows.go
diff --git a/conf/dpapi/zdpapi_windows.go b/conf/dpapi/zdpapi_windows.go
new file mode 100644
index 00000000..e48d36b2
--- /dev/null
+++ b/conf/dpapi/zdpapi_windows.go
@@ -0,0 +1,68 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package dpapi
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+ errnoERROR_IO_PENDING = 997
+)
+
+var (
+ errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+ switch e {
+ case 0:
+ return nil
+ case errnoERROR_IO_PENDING:
+ return errERROR_IO_PENDING
+ }
+ // TODO: add more here, after collecting data on the common
+ // error values see on Windows. (perhaps when running
+ // all.bat?)
+ return e
+}
+
+var (
+ modcrypt32 = windows.NewLazySystemDLL("crypt32.dll")
+
+ procCryptProtectData = modcrypt32.NewProc("CryptProtectData")
+ procCryptUnprotectData = modcrypt32.NewProc("CryptUnprotectData")
+)
+
+func cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) {
+ r1, _, e1 := syscall.Syscall9(procCryptProtectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0)
+ if r1 == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) {
+ r1, _, e1 := syscall.Syscall9(procCryptUnprotectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0)
+ if r1 == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
diff --git a/conf/mksyscall.go b/conf/mksyscall.go
new file mode 100644
index 00000000..2bdb0204
--- /dev/null
+++ b/conf/mksyscall.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go path_windows.go storewatcher_windows.go
diff --git a/conf/parser.go b/conf/parser.go
new file mode 100644
index 00000000..6a397a9d
--- /dev/null
+++ b/conf/parser.go
@@ -0,0 +1,454 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type ParseError struct {
+ why string
+ offender string
+}
+
+func (e *ParseError) Error() string {
+ return fmt.Sprintf("%s: ā€˜%sā€™", e.why, e.offender)
+}
+
+func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
+ var addrStr, cidrStr string
+ var cidr int
+
+ i := strings.IndexByte(s, '/')
+ if i < 0 {
+ addrStr = s
+ } else {
+ addrStr, cidrStr = s[:i], s[i+1:]
+ }
+
+ err = &ParseError{"Invalid IP address", s}
+ addr := net.ParseIP(addrStr)
+ if addr == nil {
+ return
+ }
+ if len(cidrStr) > 0 {
+ err = &ParseError{"Invalid network prefix length", s}
+ cidr, err = strconv.Atoi(cidrStr)
+ if err != nil || cidr < 0 || cidr > 128 {
+ return
+ }
+ if cidr > 32 && addr.To4() != nil {
+ return
+ }
+ } else {
+ if addr.To4() != nil {
+ cidr = 32
+ } else {
+ cidr = 128
+ }
+ }
+ return &IPCidr{addr, uint8(cidr)}, nil
+}
+
+func parseEndpoint(s string) (*Endpoint, error) {
+ i := strings.LastIndexByte(s, ':')
+ if i < 0 {
+ return nil, &ParseError{"Missing port from endpoint", s}
+ }
+ host, portStr := s[:i], s[i+1:]
+ if len(host) < 1 {
+ return nil, &ParseError{"Invalid endpoint host", host}
+ }
+ port, err := parsePort(portStr)
+ if err != nil {
+ return nil, err
+ }
+ hostColon := strings.IndexByte(host, ':')
+ if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
+ err := &ParseError{"Brackets must contain an IPv6 address", host}
+ if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
+ maybeV6 := net.ParseIP(host[1 : len(host)-1])
+ if maybeV6 == nil || len(maybeV6) != net.IPv6len {
+ return nil, err
+ }
+ } else {
+ return nil, err
+ }
+ host = host[1 : len(host)-1]
+ }
+ return &Endpoint{host, uint16(port)}, nil
+}
+
+func parseMTU(s string) (uint16, error) {
+ m, err := strconv.Atoi(s)
+ if err != nil {
+ return 0, err
+ }
+ if m < 576 || m > 65535 {
+ return 0, &ParseError{"Invalid MTU", s}
+ }
+ return uint16(m), nil
+}
+
+func parsePort(s string) (uint16, error) {
+ m, err := strconv.Atoi(s)
+ if err != nil {
+ return 0, err
+ }
+ if m < 0 || m > 65535 {
+ return 0, &ParseError{"Invalid port", s}
+ }
+ return uint16(m), nil
+}
+
+func parsePersistentKeepalive(s string) (uint16, error) {
+ if s == "off" {
+ return 0, nil
+ }
+ m, err := strconv.Atoi(s)
+ if err != nil {
+ return 0, err
+ }
+ if m < 0 || m > 65535 {
+ return 0, &ParseError{"Invalid persistent keepalive", s}
+ }
+ return uint16(m), nil
+}
+
+func parseKeyBase64(s string) (*Key, error) {
+ k, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return nil, &ParseError{"Invalid key: " + err.Error(), s}
+ }
+ if len(k) != KeyLength {
+ return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
+ }
+ var key Key
+ copy(key[:], k)
+ return &key, nil
+}
+
+func parseKeyHex(s string) (*Key, error) {
+ k, err := hex.DecodeString(s)
+ if err != nil {
+ return nil, &ParseError{"Invalid key: " + err.Error(), s}
+ }
+ if len(k) != KeyLength {
+ return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
+ }
+ var key Key
+ copy(key[:], k)
+ return &key, nil
+}
+
+func parseBytesOrStamp(s string) (uint64, error) {
+ b, err := strconv.ParseUint(s, 10, 64)
+ if err != nil {
+ return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s}
+ }
+ return b, nil
+}
+
+func splitList(s string) ([]string, error) {
+ var out []string
+ for _, split := range strings.Split(s, ",") {
+ trim := strings.TrimSpace(split)
+ if len(trim) == 0 {
+ return nil, &ParseError{"Two commas in a row", s}
+ }
+ out = append(out, trim)
+ }
+ return out, nil
+}
+
+type parserState int
+
+const (
+ inInterfaceSection parserState = iota
+ inPeerSection
+ notInASection
+)
+
+func (c *Config) maybeAddPeer(p *Peer) {
+ if p != nil {
+ c.Peers = append(c.Peers, *p)
+ }
+}
+
+func FromWgQuick(s string, name string) (*Config, error) {
+ lines := strings.Split(s, "\n")
+ parserState := notInASection
+ conf := Config{Name: name}
+ sawPrivateKey := false
+ var peer *Peer
+ for _, line := range lines {
+ pound := strings.IndexByte(line, '#')
+ if pound >= 0 {
+ line = line[:pound]
+ }
+ line = strings.TrimSpace(line)
+ lineLower := strings.ToLower(line)
+ if len(line) == 0 {
+ continue
+ }
+ if lineLower == "[interface]" {
+ conf.maybeAddPeer(peer)
+ parserState = inInterfaceSection
+ continue
+ }
+ if lineLower == "[peer]" {
+ conf.maybeAddPeer(peer)
+ peer = &Peer{}
+ parserState = inPeerSection
+ continue
+ }
+ if parserState == notInASection {
+ return nil, &ParseError{"Line must occur in a section", line}
+ }
+ equals := strings.IndexByte(line, '=')
+ if equals < 0 {
+ return nil, &ParseError{"Invalid config key is missing an equals separator", line}
+ }
+ key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
+ if len(val) == 0 {
+ return nil, &ParseError{"Key must have a value", line}
+ }
+ if parserState == inInterfaceSection {
+ switch key {
+ case "privatekey":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.PrivateKey = *k
+ sawPrivateKey = true
+ case "listenport":
+ p, err := parsePort(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.ListenPort = p
+ case "mtu":
+ m, err := parseMTU(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.Mtu = m
+ case "address":
+ addresses, err := splitList(val)
+ if err != nil {
+ return nil, err
+ }
+ for _, address := range addresses {
+ a, err := parseIPCidr(address)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
+ }
+ case "dns":
+ addresses, err := splitList(val)
+ if err != nil {
+ return nil, err
+ }
+ for _, address := range addresses {
+ a := net.ParseIP(address)
+ if a == nil {
+ return nil, &ParseError{"Invalid IP address", address}
+ }
+ conf.Interface.Dns = append(conf.Interface.Dns, a)
+ }
+ default:
+ return nil, &ParseError{"Invalid key for [Interface] section", key}
+ }
+ } else if parserState == inPeerSection {
+ switch key {
+ case "publickey":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PublicKey = *k
+ case "presharedkey":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PresharedKey = *k
+ case "allowedips":
+ addresses, err := splitList(val)
+ if err != nil {
+ return nil, err
+ }
+ for _, address := range addresses {
+ a, err := parseIPCidr(address)
+ if err != nil {
+ return nil, err
+ }
+ peer.AllowedIPs = append(peer.AllowedIPs, *a)
+ }
+ case "persistentkeepalive":
+ p, err := parsePersistentKeepalive(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PersistentKeepalive = p
+ case "endpoint":
+ e, err := parseEndpoint(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.Endpoint = *e
+ default:
+ return nil, &ParseError{"Invalid key for [Peer] section", key}
+ }
+ }
+ }
+ conf.maybeAddPeer(peer)
+
+ if !sawPrivateKey {
+ return nil, &ParseError{"An interface must have a private key", "[none specified]"}
+ }
+ for _, p := range conf.Peers {
+ if p.PublicKey.IsZero() {
+ return nil, &ParseError{"All peers must have public keys", "[none specified]"}
+ }
+ }
+
+ return &conf, nil
+}
+
+func FromUAPI(s string, existingConfig *Config) (*Config, error) {
+ lines := strings.Split(s, "\n")
+ parserState := inInterfaceSection
+ conf := Config{
+ Name: existingConfig.Name,
+ Interface: Interface{
+ Addresses: existingConfig.Interface.Addresses,
+ Dns: existingConfig.Interface.Dns,
+ Mtu: existingConfig.Interface.Mtu,
+ },
+ }
+ var peer *Peer
+ for _, line := range lines {
+ if len(line) == 0 {
+ continue
+ }
+ equals := strings.IndexByte(line, '=')
+ if equals < 0 {
+ return nil, &ParseError{"Invalid config key is missing an equals separator", line}
+ }
+ key, val := line[:equals], line[equals+1:]
+ if len(val) == 0 {
+ return nil, &ParseError{"Key must have a value", line}
+ }
+ switch key {
+ case "public_key":
+ conf.maybeAddPeer(peer)
+ peer = &Peer{}
+ parserState = inPeerSection
+ case "errno":
+ if val == "0" {
+ continue
+ } else {
+ return nil, &ParseError{"Error in getting configuration", val}
+ }
+ }
+ if parserState == inInterfaceSection {
+ switch key {
+ case "private_key":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.PrivateKey = *k
+ case "listen_port":
+ p, err := parsePort(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.ListenPort = p
+ case "fwmark":
+ // Ignored for now.
+
+ default:
+ return nil, &ParseError{"Invalid key for interface section", key}
+ }
+ } else if parserState == inPeerSection {
+ switch key {
+ case "public_key":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PublicKey = *k
+ case "preshared_key":
+ k, err := parseKeyBase64(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PresharedKey = *k
+ case "protocol_version":
+ if val != "1" {
+ return nil, &ParseError{"Protocol version must be 1", val}
+ }
+ case "allowed_ip":
+ a, err := parseIPCidr(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.AllowedIPs = append(peer.AllowedIPs, *a)
+ case "persistent_keepalive_interval":
+ p, err := parsePersistentKeepalive(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.PersistentKeepalive = p
+ case "endpoint":
+ e, err := parseEndpoint(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.Endpoint = *e
+ case "tx_bytes":
+ b, err := parseBytesOrStamp(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.TxBytes = Bytes(b)
+ case "rx_bytes":
+ b, err := parseBytesOrStamp(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.RxBytes = Bytes(b)
+ case "last_handshake_time_sec":
+ t, err := parseBytesOrStamp(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second)
+ case "last_handshake_time_nsec":
+ t, err := parseBytesOrStamp(val)
+ if err != nil {
+ return nil, err
+ }
+ peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond)
+ default:
+ return nil, &ParseError{"Invalid key for peer section", key}
+ }
+ }
+ }
+ conf.maybeAddPeer(peer)
+
+ return &conf, nil
+}
diff --git a/conf/parser_test.go b/conf/parser_test.go
new file mode 100644
index 00000000..a6afbf53
--- /dev/null
+++ b/conf/parser_test.go
@@ -0,0 +1,128 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "net"
+ "reflect"
+ "runtime"
+ "testing"
+)
+
+const testInput = `
+[Interface]
+Address = 10.192.122.1/24
+Address = 10.10.0.1/16
+PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
+ListenPort = 51820 #comments don't matter
+
+[Peer]
+PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
+Endpoint = 192.95.5.67:1234
+AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
+
+[Peer]
+PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
+Endpoint = [2607:5300:60:6b0::c05f:543]:2468
+AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
+PersistentKeepalive = 100
+
+[Peer]
+PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
+PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
+Endpoint = test.wireguard.com:18981
+AllowedIPs = 10.10.10.230/32`
+
+func noError(t *testing.T, err error) bool {
+ if err == nil {
+ return true
+ }
+ _, fn, line, _ := runtime.Caller(1)
+ t.Errorf("Error at %s:%d: %#v", fn, line, err)
+ return false
+}
+
+func equal(t *testing.T, expected, actual interface{}) bool {
+ if reflect.DeepEqual(expected, actual) {
+ return true
+ }
+ _, fn, line, _ := runtime.Caller(1)
+ t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
+ return false
+}
+func lenTest(t *testing.T, actualO interface{}, expected int) bool {
+ actual := reflect.ValueOf(actualO).Len()
+ if reflect.DeepEqual(expected, actual) {
+ return true
+ }
+ _, fn, line, _ := runtime.Caller(1)
+ t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
+ return false
+}
+func contains(t *testing.T, list, element interface{}) bool {
+ listValue := reflect.ValueOf(list)
+ for i := 0; i < listValue.Len(); i++ {
+ if reflect.DeepEqual(listValue.Index(i).Interface(), element) {
+ return true
+ }
+ }
+ _, fn, line, _ := runtime.Caller(1)
+ t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element)
+ return false
+}
+
+func TestFromWgQuick(t *testing.T) {
+ conf, err := FromWgQuick(testInput, "test")
+ if noError(t, err) {
+
+ lenTest(t, conf.Interface.Addresses, 2)
+ contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
+ contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
+ equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
+ equal(t, uint16(51820), conf.Interface.ListenPort)
+
+ lenTest(t, conf.Peers, 3)
+ lenTest(t, conf.Peers[0].AllowedIPs, 2)
+ equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoint)
+ equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.String())
+
+ lenTest(t, conf.Peers[1].AllowedIPs, 2)
+ equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoint)
+ equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.String())
+ equal(t, uint16(100), conf.Peers[1].PersistentKeepalive)
+
+ lenTest(t, conf.Peers[2].AllowedIPs, 1)
+ equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoint)
+ equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.String())
+ equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.String())
+ }
+}
+
+func TestParseEndpoint(t *testing.T) {
+ _, err := parseEndpoint("[192.168.42.0:]:51880")
+ if err == nil {
+ t.Error("Error was expected")
+ }
+ e, err := parseEndpoint("192.168.42.0:51880")
+ if noError(t, err) {
+ equal(t, "192.168.42.0", e.Host)
+ equal(t, uint16(51880), e.Port)
+ }
+ e, err = parseEndpoint("test.wireguard.com:18981")
+ if noError(t, err) {
+ equal(t, "test.wireguard.com", e.Host)
+ equal(t, uint16(18981), e.Port)
+ }
+ e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468")
+ if noError(t, err) {
+ equal(t, "2607:5300:60:6b0::c05f:543", e.Host)
+ equal(t, uint16(2468), e.Port)
+ }
+ _, err = parseEndpoint("[::::::invalid:18981")
+ if err == nil {
+ t.Error("Error was expected")
+ }
+}
diff --git a/conf/path_windows.go b/conf/path_windows.go
new file mode 100644
index 00000000..0ee3fc73
--- /dev/null
+++ b/conf/path_windows.go
@@ -0,0 +1,50 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "errors"
+ "golang.org/x/sys/windows"
+ "os"
+ "path/filepath"
+ "unsafe"
+)
+
+//sys coTaskMemFree(pointer uintptr) = ole32.CoTaskMemFree
+//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) [failretval!=0] = shell32.SHGetKnownFolderPath
+var folderIDLocalAppData = windows.GUID{0xf1b32785, 0x6fba, 0x4fcf, [8]byte{0x9d, 0x55, 0x7b, 0x8e, 0x7f, 0x15, 0x70, 0x91}}
+
+const kfFlagCreate = 0x00008000
+
+var cachedConfigFileDir string
+
+func resolveConfigFileDir() (string, error) {
+ if cachedConfigFileDir != "" {
+ return cachedConfigFileDir, nil
+ }
+ processToken, err := windows.OpenCurrentProcessToken()
+ if err != nil {
+ return "", err
+ }
+ defer processToken.Close()
+ var path *uint16
+ err = shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, windows.Handle(processToken), &path)
+ if err != nil {
+ return "", err
+ }
+ defer coTaskMemFree(uintptr(unsafe.Pointer(path)))
+ root := windows.UTF16ToString((*[windows.MAX_LONG_PATH + 1]uint16)(unsafe.Pointer(path))[:])
+ if len(root) == 0 {
+ return "", errors.New("Unable to determine configuration directory")
+ }
+ c := filepath.Join(root, "WireGuard", "Configurations")
+ err = os.MkdirAll(c, os.ModeDir|0700)
+ if err != nil {
+ return "", err
+ }
+ cachedConfigFileDir = c
+ return cachedConfigFileDir, nil
+}
diff --git a/conf/store.go b/conf/store.go
new file mode 100644
index 00000000..7c110865
--- /dev/null
+++ b/conf/store.go
@@ -0,0 +1,199 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "errors"
+ "golang.zx2c4.com/wireguard/windows/conf/dpapi"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+)
+
+const configFileSuffix = ".conf.dpapi"
+const configFileUnencryptedSuffix = ".conf"
+
+func ListConfigNames() ([]string, error) {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return nil, err
+ }
+ files, err := ioutil.ReadDir(configFileDir)
+ if err != nil {
+ return nil, err
+ }
+ configs := make([]string, len(files))
+ i := 0
+ for _, file := range files {
+ name := filepath.Base(file.Name())
+ if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) {
+ continue
+ }
+ if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 {
+ continue
+ }
+ configs[i] = strings.TrimSuffix(name, configFileSuffix)
+ i++
+ }
+ return configs[:i], nil
+}
+
+func MigrateUnencryptedConfigs() (int, []error) {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return 0, []error{err}
+ }
+ files, err := ioutil.ReadDir(configFileDir)
+ if err != nil {
+ return 0, []error{err}
+ }
+ errs := make([]error, len(files))
+ i := 0
+ e := 0
+ for _, file := range files {
+ path := filepath.Join(configFileDir, file.Name())
+ name := filepath.Base(file.Name())
+ if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) {
+ continue
+ }
+ if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 {
+ continue
+ }
+
+ // We don't use ioutil's ReadFile, because we actually want RDWR, so that we can take advantage
+ // of Windows file locking for ensuring the file is finished being written.
+ f, err := os.OpenFile(path, os.O_RDWR, 0)
+ if err != nil {
+ errs[e] = err
+ e++
+ continue
+ }
+ bytes, err := ioutil.ReadAll(f)
+ f.Close()
+ if err != nil {
+ errs[e] = err
+ e++
+ continue
+ }
+ _, err = FromWgQuick(string(bytes), "input")
+ if err != nil {
+ errs[e] = err
+ e++
+ continue
+ }
+
+ bytes, err = dpapi.Encrypt(bytes, strings.TrimSuffix(name, configFileUnencryptedSuffix))
+ if err != nil {
+ errs[e] = err
+ e++
+ continue
+ }
+ dstFile := strings.TrimSuffix(path, configFileUnencryptedSuffix) + configFileSuffix
+ if _, err = os.Stat(dstFile); err != nil && !os.IsNotExist(err) {
+ errs[e] = errors.New("Unable to migrate to " + dstFile + " as it already exists")
+ e++
+ continue
+ }
+ err = ioutil.WriteFile(dstFile, bytes, 0600)
+ if err != nil {
+ errs[e] = err
+ e++
+ continue
+ }
+ err = os.Remove(path)
+ if err != nil && os.Remove(dstFile) == nil {
+ errs[e] = err
+ e++
+ continue
+ }
+ i++
+ }
+ return i, errs[:e]
+}
+
+func LoadFromName(name string) (*Config, error) {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return nil, err
+ }
+ return LoadFromPath(filepath.Join(configFileDir, name+configFileSuffix))
+}
+
+func LoadFromPath(path string) (*Config, error) {
+ name, err := NameFromPath(path)
+ if err != nil {
+ return nil, err
+ }
+ bytes, err := ioutil.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ if strings.HasSuffix(path, configFileSuffix) {
+ bytes, err = dpapi.Decrypt(bytes, name)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return FromWgQuick(string(bytes), name)
+}
+
+func NameFromPath(path string) (string, error) {
+ name := filepath.Base(path)
+ if !((len(name) > len(configFileSuffix) && strings.HasSuffix(name, configFileSuffix)) ||
+ (len(name) > len(configFileUnencryptedSuffix) && strings.HasSuffix(name, configFileUnencryptedSuffix))) {
+ return "", errors.New("Path must end in either " + configFileSuffix + " or " + configFileUnencryptedSuffix)
+ }
+ if strings.HasSuffix(path, configFileSuffix) {
+ name = strings.TrimSuffix(name, configFileSuffix)
+ } else {
+ name = strings.TrimSuffix(name, configFileUnencryptedSuffix)
+ }
+ return name, nil
+}
+
+func (config *Config) Save() error {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return err
+ }
+ filename := filepath.Join(configFileDir, config.Name+configFileSuffix)
+ bytes := []byte(config.ToWgQuick())
+ bytes, err = dpapi.Encrypt(bytes, config.Name)
+ if err != nil {
+ return err
+ }
+ err = ioutil.WriteFile(filename+".tmp", bytes, 0600)
+ if err != nil {
+ return err
+ }
+ err = os.Rename(filename+".tmp", filename)
+ if err != nil {
+ os.Remove(filename + ".tmp")
+ return err
+ }
+ return nil
+}
+
+func (config *Config) Path() (string, error) {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(configFileDir, config.Name+configFileSuffix), nil
+}
+
+func DeleteName(name string) error {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return err
+ }
+ return os.Remove(filepath.Join(configFileDir, name+configFileSuffix))
+}
+
+func (config *Config) Delete() error {
+ return DeleteName(config.Name)
+}
diff --git a/conf/store_test.go b/conf/store_test.go
new file mode 100644
index 00000000..fdef7ea7
--- /dev/null
+++ b/conf/store_test.go
@@ -0,0 +1,91 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestStorage(t *testing.T) {
+ c, err := FromWgQuick(testInput, "golangTest")
+ if err != nil {
+ t.Errorf("Unable to parse test config: %s", err.Error())
+ return
+ }
+
+ err = c.Save()
+ if err != nil {
+ t.Errorf("Unable to save config: %s", err.Error())
+ }
+
+ configs, err := ListConfigNames()
+ if err != nil {
+ t.Errorf("Unable to list configs: %s", err.Error())
+ }
+
+ found := false
+ for _, name := range configs {
+ if name == "golangTest" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("Unable to find saved config in list")
+ }
+
+ loaded, err := LoadFromName("golangTest")
+ if err != nil {
+ t.Errorf("Unable to load config: %s", err.Error())
+ return
+ }
+
+ if !reflect.DeepEqual(loaded, c) {
+ t.Error("Loaded config is not the same as saved config")
+ }
+
+ k, err := NewPrivateKey()
+ if err != nil {
+ t.Errorf("Unable to generate new private key: %s", err.Error())
+ }
+ c.Interface.PrivateKey = *k
+
+ err = c.Save()
+ if err != nil {
+ t.Errorf("Unable to save config a second time: %s", err.Error())
+ }
+
+ loaded, err = LoadFromName("golangTest")
+ if err != nil {
+ t.Errorf("Unable to load config a second time: %s", err.Error())
+ return
+ }
+
+ if !reflect.DeepEqual(loaded, c) {
+ t.Error("Second loaded config is not the same as second saved config")
+ }
+
+ err = DeleteName("golangTest")
+ if err != nil {
+ t.Errorf("Unable to delete config: %s", err.Error())
+ }
+
+ configs, err = ListConfigNames()
+ if err != nil {
+ t.Errorf("Unable to list configs: %s", err.Error())
+ }
+ found = false
+ for _, name := range configs {
+ if name == "golangTest" {
+ found = true
+ break
+ }
+ }
+ if found {
+ t.Error("Config wasn't actually deleted")
+ }
+}
diff --git a/conf/storewatcher.go b/conf/storewatcher.go
new file mode 100644
index 00000000..4f9c2ef7
--- /dev/null
+++ b/conf/storewatcher.go
@@ -0,0 +1,38 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import "reflect"
+
+type StoreCallback func()
+
+var storeCallbacks []StoreCallback
+
+func RegisterStoreChangeCallback(cb StoreCallback) {
+ startWatchingConfigDir()
+ cb()
+ storeCallbacks = append(storeCallbacks, cb)
+}
+
+func UnregisterStoreChangeCallback(cb StoreCallback) {
+ //TODO: this function is ridiculous, doing slow iteration like this and reflection too.
+
+ index := -1
+ for i, e := range storeCallbacks {
+ if reflect.ValueOf(e).Pointer() == reflect.ValueOf(cb).Pointer() {
+ index = i
+ break
+ }
+ }
+ if index == -1 {
+ return
+ }
+ newList := storeCallbacks[0:index]
+ if index < len(storeCallbacks)-1 {
+ newList = append(newList, storeCallbacks[index+1:]...)
+ }
+ storeCallbacks = newList
+}
diff --git a/conf/storewatcher_windows.go b/conf/storewatcher_windows.go
new file mode 100644
index 00000000..f3f38fef
--- /dev/null
+++ b/conf/storewatcher_windows.go
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "golang.org/x/sys/windows"
+ "log"
+)
+
+const (
+ fncFILE_NAME uint32 = 0x00000001
+ fncDIR_NAME uint32 = 0x00000002
+ fncATTRIBUTES uint32 = 0x00000004
+ fncSIZE uint32 = 0x00000008
+ fncLAST_WRITE uint32 = 0x00000010
+ fncLAST_ACCESS uint32 = 0x00000020
+ fncCREATION uint32 = 0x00000040
+ fncSECURITY uint32 = 0x00000100
+)
+
+//sys findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) = kernel32.FindFirstChangeNotificationW
+//sys findNextChangeNotification(handle windows.Handle) (err error) = kernel32.FindNextChangeNotification
+
+var haveStartedWatchingConfigDir bool
+
+func startWatchingConfigDir() {
+ if haveStartedWatchingConfigDir {
+ return
+ }
+ haveStartedWatchingConfigDir = true
+ go func() {
+ configFileDir, err := resolveConfigFileDir()
+ if err != nil {
+ return
+ }
+ h, err := findFirstChangeNotification(windows.StringToUTF16Ptr(configFileDir), true, fncFILE_NAME|fncDIR_NAME|fncATTRIBUTES|fncSIZE|fncLAST_WRITE|fncLAST_ACCESS|fncCREATION|fncSECURITY)
+ if err != nil {
+ log.Fatalf("Unable to monitor config directory: %v", err)
+ }
+ for {
+ s, err := windows.WaitForSingleObject(h, windows.INFINITE)
+ if err != nil || s == windows.WAIT_FAILED {
+ log.Fatalf("Unable to wait on config directory watcher: %v", err)
+ }
+
+ for _, cb := range storeCallbacks {
+ cb()
+ }
+
+ err = findNextChangeNotification(h)
+ if err != nil {
+ log.Fatalf("Unable to monitor config directory again: %v", err)
+ }
+ }
+ }()
+}
diff --git a/conf/writer.go b/conf/writer.go
new file mode 100644
index 00000000..642d14a7
--- /dev/null
+++ b/conf/writer.go
@@ -0,0 +1,125 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+)
+
+func (conf *Config) ToWgQuick() string {
+ var output strings.Builder
+ output.WriteString("[Interface]\n")
+
+ output.WriteString(fmt.Sprintf("PrivateKey = %s\n", conf.Interface.PrivateKey.String()))
+
+ if conf.Interface.ListenPort > 0 {
+ output.WriteString(fmt.Sprintf("ListenPort = %d\n", conf.Interface.ListenPort))
+ }
+
+ if len(conf.Interface.Addresses) > 0 {
+ addrStrings := make([]string, len(conf.Interface.Addresses))
+ for i, address := range conf.Interface.Addresses {
+ addrStrings[i] = address.String()
+ }
+ output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", ")))
+ }
+
+ if len(conf.Interface.Dns) > 0 {
+ addrStrings := make([]string, len(conf.Interface.Dns))
+ for i, address := range conf.Interface.Dns {
+ addrStrings[i] = address.String()
+ }
+ output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", ")))
+ }
+
+ if conf.Interface.Mtu > 0 {
+ output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.Mtu))
+ }
+
+ for _, peer := range conf.Peers {
+ output.WriteString("\n[Peer]\n")
+
+ output.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey.String()))
+
+ if !peer.PresharedKey.IsZero() {
+ output.WriteString(fmt.Sprintf("PresharedKey = %s\n", peer.PresharedKey.String()))
+ }
+
+ if len(peer.AllowedIPs) > 0 {
+ addrStrings := make([]string, len(peer.AllowedIPs))
+ for i, address := range peer.AllowedIPs {
+ addrStrings[i] = address.String()
+ }
+ output.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(addrStrings[:], ", ")))
+ }
+
+ if !peer.Endpoint.IsEmpty() {
+ output.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint.String()))
+ }
+
+ if peer.PersistentKeepalive > 0 {
+ output.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", peer.PersistentKeepalive))
+ }
+ }
+ return output.String()
+}
+
+func (conf *Config) ToUAPI() (string, error) {
+ var output strings.Builder
+ output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString()))
+
+ if conf.Interface.ListenPort > 0 {
+ output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort))
+ }
+
+ if len(conf.Peers) > 0 {
+ output.WriteString("replace_peers=true\n")
+ }
+
+ for _, peer := range conf.Peers {
+ output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString()))
+
+ if !peer.PresharedKey.IsZero() {
+ output.WriteString(fmt.Sprintf("preshared_key = %s\n", peer.PresharedKey.String()))
+ }
+
+ if !peer.Endpoint.IsEmpty() {
+ ips, err := net.LookupIP(peer.Endpoint.Host)
+ if err != nil {
+ return "", err
+ }
+ var ip net.IP
+ for _, iterip := range ips {
+ iterip = iterip.To4()
+ if iterip != nil {
+ ip = iterip
+ break
+ }
+ if ip == nil {
+ ip = iterip
+ }
+ }
+ if ip == nil {
+ return "", errors.New("Unable to resolve IP address of endpoint")
+ }
+ resolvedEndpoint := Endpoint{ip.String(), peer.Endpoint.Port}
+ output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String()))
+ }
+
+ output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive))
+
+ if len(peer.AllowedIPs) > 0 {
+ output.WriteString("replace_allowed_ips=true\n")
+ for _, address := range peer.AllowedIPs {
+ output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String()))
+ }
+ }
+ }
+ return output.String(), nil
+}
diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go
new file mode 100644
index 00000000..64bec1f4
--- /dev/null
+++ b/conf/zsyscall_windows.go
@@ -0,0 +1,96 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package conf
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+ errnoERROR_IO_PENDING = 997
+)
+
+var (
+ errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+ switch e {
+ case 0:
+ return nil
+ case errnoERROR_IO_PENDING:
+ return errERROR_IO_PENDING
+ }
+ // TODO: add more here, after collecting data on the common
+ // error values see on Windows. (perhaps when running
+ // all.bat?)
+ return e
+}
+
+var (
+ modole32 = windows.NewLazySystemDLL("ole32.dll")
+ modshell32 = windows.NewLazySystemDLL("shell32.dll")
+ modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+
+ procCoTaskMemFree = modole32.NewProc("CoTaskMemFree")
+ procSHGetKnownFolderPath = modshell32.NewProc("SHGetKnownFolderPath")
+ procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW")
+ procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification")
+)
+
+func coTaskMemFree(pointer uintptr) {
+ syscall.Syscall(procCoTaskMemFree.Addr(), 1, uintptr(pointer), 0, 0)
+ return
+}
+
+func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) {
+ r1, _, e1 := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0)
+ if r1 != 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) {
+ var _p0 uint32
+ if watchSubtree {
+ _p0 = 1
+ } else {
+ _p0 = 0
+ }
+ r0, _, e1 := syscall.Syscall(procFindFirstChangeNotificationW.Addr(), 3, uintptr(unsafe.Pointer(path)), uintptr(_p0), uintptr(filter))
+ handle = windows.Handle(r0)
+ if handle == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func findNextChangeNotification(handle windows.Handle) (err error) {
+ r1, _, e1 := syscall.Syscall(procFindNextChangeNotification.Addr(), 1, uintptr(handle), 0, 0)
+ if r1 == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}