aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/updater
diff options
context:
space:
mode:
Diffstat (limited to 'updater')
-rw-r--r--updater/constants.go14
-rw-r--r--updater/downloader.go181
-rw-r--r--updater/msirunner_linux.go21
-rw-r--r--updater/msirunner_windows.go32
-rw-r--r--updater/signify.go73
-rw-r--r--updater/updater_test.go41
-rw-r--r--updater/versions.go85
7 files changed, 447 insertions, 0 deletions
diff --git a/updater/constants.go b/updater/constants.go
new file mode 100644
index 00000000..ae3988bd
--- /dev/null
+++ b/updater/constants.go
@@ -0,0 +1,14 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+const (
+ releasePublicKeyBase64 = "RWQGxwD+15iPpnPCEijYJ3CWYFgojWwBJZNg0OnJfICVu/CfyKeQ0vIA"
+ latestVersionURL = "https://download.wireguard.com/windows-client/latest.sig"
+ msiURL = "https://download.wireguard.com/windows-client/%s"
+ msiArchPrefix = "wireguard-%s-"
+ msiSuffix = ".msi"
+)
diff --git a/updater/downloader.go b/updater/downloader.go
new file mode 100644
index 00000000..ea3ee9d4
--- /dev/null
+++ b/updater/downloader.go
@@ -0,0 +1,181 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "crypto/hmac"
+ "errors"
+ "fmt"
+ "golang.org/x/crypto/blake2b"
+ "golang.zx2c4.com/wireguard/windows/version"
+ "hash"
+ "io"
+ "net/http"
+ "os"
+ "path"
+ "sync/atomic"
+)
+
+type DownloadProgress struct {
+ Activity string
+ BytesDownloaded uint64
+ BytesTotal uint64
+ Error error
+ Complete bool
+}
+
+type progressHashWatcher struct {
+ dp *DownloadProgress
+ c chan DownloadProgress
+ hashState hash.Hash
+}
+
+func (pm *progressHashWatcher) Write(p []byte) (int, error) {
+ bytes := len(p)
+ pm.dp.BytesDownloaded += uint64(bytes)
+ pm.c <- *pm.dp
+ pm.hashState.Write(p)
+ return bytes, nil
+}
+
+type UpdateFound struct {
+ name string
+ hash [blake2b.Size256]byte
+}
+
+func CheckForUpdate() (*UpdateFound, error) {
+ request, err := http.NewRequest(http.MethodGet, latestVersionURL, nil)
+ if err != nil {
+ return nil, err
+ }
+ request.Header.Add("User-Agent", version.UserAgent())
+ response, err := http.DefaultClient.Do(request)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+ var fileList [1024 * 512] /* 512 KiB */ byte
+ bytesRead, err := response.Body.Read(fileList[:])
+ if err != nil && (err != io.EOF || bytesRead == 0) {
+ return nil, err
+ }
+ files, err := readFileList(fileList[:bytesRead])
+ if err != nil {
+ return nil, err
+ }
+ return findCandidate(files)
+}
+
+var updateInProgress = uint32(0)
+
+func DownloadVerifyAndExecute() (progress chan DownloadProgress) {
+ progress = make(chan DownloadProgress, 128)
+ progress <- DownloadProgress{Activity: "Initializing"}
+
+ if !atomic.CompareAndSwapUint32(&updateInProgress, 0, 1) {
+ progress <- DownloadProgress{Error: errors.New("An update is already in progress")}
+ return
+ }
+
+ go func() {
+ defer atomic.StoreUint32(&updateInProgress, 0)
+
+ progress <- DownloadProgress{Activity: "Rechecking for update"}
+ update, err := CheckForUpdate()
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ if update == nil {
+ progress <- DownloadProgress{Error: errors.New("No update was found when re-checking for updates")}
+ return
+ }
+
+ progress <- DownloadProgress{Activity: "Creating update file"}
+ updateDir, err := msiSaveDirectory()
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ // Clean up old updates the brutal way:
+ os.RemoveAll(updateDir)
+
+ err = os.MkdirAll(updateDir, 0700)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ destinationFilename := path.Join(updateDir, update.name)
+ unverifiedDestinationFilename := destinationFilename + ".unverified"
+ out, err := os.Create(unverifiedDestinationFilename)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ defer func() {
+ if out != nil {
+ out.Seek(0, io.SeekStart)
+ out.Truncate(0)
+ out.Close()
+ os.Remove(unverifiedDestinationFilename)
+ }
+ }()
+
+ dp := DownloadProgress{Activity: "Downloading update"}
+ progress <- dp
+ request, err := http.NewRequest(http.MethodGet, fmt.Sprintf(msiURL, update.name), nil)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ request.Header.Add("User-Agent", version.UserAgent())
+ request.Header.Set("Accept-Encoding", "identity")
+ response, err := http.DefaultClient.Do(request)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ defer response.Body.Close()
+ if response.ContentLength >= 0 {
+ dp.BytesTotal = uint64(response.ContentLength)
+ progress <- dp
+ }
+ hasher, err := blake2b.New256(nil)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ pm := &progressHashWatcher{&dp, progress, hasher}
+ _, err = io.Copy(out, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm))
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ if !hmac.Equal(hasher.Sum(nil), update.hash[:]) {
+ progress <- DownloadProgress{Error: errors.New("The downloaded update has the wrong hash")}
+ return
+ }
+ out.Close()
+ out = nil
+ err = os.Rename(unverifiedDestinationFilename, destinationFilename)
+ if err != nil {
+ os.Remove(unverifiedDestinationFilename)
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+ progress <- DownloadProgress{Activity: "Installing update"}
+ err = runMsi(destinationFilename)
+ os.Remove(unverifiedDestinationFilename)
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ return
+ }
+
+ progress <- DownloadProgress{Complete: true}
+ }()
+
+ return progress
+}
diff --git a/updater/msirunner_linux.go b/updater/msirunner_linux.go
new file mode 100644
index 00000000..974c0883
--- /dev/null
+++ b/updater/msirunner_linux.go
@@ -0,0 +1,21 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+// This isn't a Linux program, yes, but having the updater package work across platforms is quite helpful for testing.
+
+func runMsi(msiPath string) error {
+ return exec.Command("qarma", "--info", "--text", fmt.Sprintf("It seems to be working! Were we on Windows, \"%s\" would be executed.", msiPath)).Run()
+}
+
+func msiSaveDirectory() (string, error) {
+ return "/tmp/wireguard-update-test-msi-directory", nil
+}
diff --git a/updater/msirunner_windows.go b/updater/msirunner_windows.go
new file mode 100644
index 00000000..a498af3c
--- /dev/null
+++ b/updater/msirunner_windows.go
@@ -0,0 +1,32 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "os/exec"
+ "path"
+)
+
+func runMsi(msiPath string) error {
+ system32, err := windows.GetSystemDirectory()
+ if err != nil {
+ return err
+ }
+ // BUG: The Go documentation says that its built-in shell quoting isn't good for msiexec.exe.
+ // See https://github.com/golang/go/issues/15566. But perhaps our limited set of options
+ // actually works fine? Investigate this!
+ return exec.Command(path.Join(system32, "msiexec.exe"), "/quiet", "/i", msiPath).Run()
+}
+
+func msiSaveDirectory() (string, error) {
+ configRootDir, err := conf.RootDirectory()
+ if err != nil {
+ return "", err
+ }
+ return path.Join(configRootDir, "Updates"), nil
+}
diff --git a/updater/signify.go b/updater/signify.go
new file mode 100644
index 00000000..d4605cbb
--- /dev/null
+++ b/updater/signify.go
@@ -0,0 +1,73 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/hex"
+ "errors"
+ "golang.org/x/crypto/blake2b"
+ "golang.org/x/crypto/ed25519"
+ "strings"
+)
+
+/*
+ * Generate with:
+ * $ b2sum -l 256 *.msi > list
+ * $ signify -S -e -s release.sec -m list
+ * $ upload ./list.sec
+ */
+
+type fileList map[string][blake2b.Size256]byte
+
+func readFileList(input []byte) (fileList, error) {
+ publicKeyBytes, err := base64.StdEncoding.DecodeString(releasePublicKeyBase64)
+ if err != nil || len(publicKeyBytes) != ed25519.PublicKeySize+10 || publicKeyBytes[0] != 'E' || publicKeyBytes[1] != 'd' {
+ return nil, errors.New("Invalid public key")
+ }
+ publicKeyBytes = publicKeyBytes[10:]
+ lines := bytes.SplitN(input, []byte{'\n'}, 3)
+ if len(lines) != 3 {
+ return nil, errors.New("Signature input has too few lines")
+ }
+ if !bytes.HasPrefix(lines[0], []byte("untrusted comment: ")) {
+ return nil, errors.New("Signature input is missing untrusted comment")
+ }
+ signatureBytes, err := base64.StdEncoding.DecodeString(string(lines[1]))
+ if err != nil {
+ return nil, errors.New("Signature input is not valid base64")
+ }
+ if len(signatureBytes) != ed25519.SignatureSize+10 || signatureBytes[0] != 'E' || signatureBytes[1] != 'd' {
+ return nil, errors.New("Signature input bytes are incorrect length or represent invalid signature type")
+ }
+ signatureBytes = signatureBytes[10:]
+ if !ed25519.Verify(publicKeyBytes, lines[2], signatureBytes) {
+ return nil, errors.New("Signature is invalid")
+ }
+ fileLines := strings.Split(string(lines[2]), "\n")
+ fileHashes := make(map[string][blake2b.Size256]byte, len(fileLines))
+ for index, line := range fileLines {
+ if len(line) == 0 && index == len(fileLines)-1 {
+ break
+ }
+ components := strings.SplitN(line, " ", 2)
+ if len(components) != 2 {
+ return nil, errors.New("File hash line has too few components")
+ }
+ maybeHash, err := hex.DecodeString(components[0])
+ if err != nil || len(maybeHash) != blake2b.Size256 {
+ return nil, errors.New("File hash is invalid base64 or incorrect number of bytes")
+ }
+ var hash [blake2b.Size256]byte
+ copy(hash[:], maybeHash)
+ fileHashes[components[1]] = hash
+ }
+ if len(fileHashes) == 0 {
+ return nil, errors.New("No file hashes found in signed input")
+ }
+ return fileHashes, nil
+}
diff --git a/updater/updater_test.go b/updater/updater_test.go
new file mode 100644
index 00000000..7bc4df8e
--- /dev/null
+++ b/updater/updater_test.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "testing"
+)
+
+func TestUpdate(t *testing.T) {
+ update, err := CheckForUpdate()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if update == nil {
+ t.Error("No update available")
+ return
+ }
+ t.Log("Found update")
+ progress := DownloadVerifyAndExecute()
+ for {
+ dp := <-progress
+ if dp.Error != nil {
+ t.Error(dp.Error)
+ return
+ }
+ if len(dp.Activity) > 0 {
+ t.Log(dp.Activity)
+ }
+ if dp.BytesTotal > 0 {
+ t.Logf("Downloaded %d of %d", dp.BytesDownloaded, dp.BytesTotal)
+ }
+ if dp.Complete {
+ t.Log("Complete!")
+ break
+ }
+ }
+}
diff --git a/updater/versions.go b/updater/versions.go
new file mode 100644
index 00000000..a5b6c258
--- /dev/null
+++ b/updater/versions.go
@@ -0,0 +1,85 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "errors"
+ "fmt"
+ "golang.zx2c4.com/wireguard/windows/version"
+ "runtime"
+ "strconv"
+ "strings"
+)
+
+func versionNewerThanUs(candidate string) (bool, error) {
+ candidateParts := strings.Split(candidate, ".")
+ ourParts := strings.Split(version.WireGuardWindowsVersion, ".")
+ if len(candidateParts) == 0 || len(ourParts) == 0 {
+ return false, errors.New("Empty version")
+ }
+ l := len(candidateParts)
+ if len(ourParts) > l {
+ l = len(ourParts)
+ }
+ for i := 0; i < l; i++ {
+ var err error
+ cP, oP := uint64(0), uint64(0)
+ if i < len(candidateParts) {
+ if len(candidateParts[i]) == 0 {
+ return false, errors.New("Empty version part")
+ }
+ cP, err = strconv.ParseUint(candidateParts[i], 10, 16)
+ if err != nil {
+ return false, errors.New("Invalid version integer part")
+ }
+ }
+ if i < len(ourParts) {
+ if len(ourParts[i]) == 0 {
+ return false, errors.New("Empty version part")
+ }
+ oP, err = strconv.ParseUint(ourParts[i], 10, 16)
+ if err != nil {
+ return false, errors.New("Invalid version integer part")
+ }
+ }
+ if cP == oP {
+ continue
+ }
+ return cP > oP, nil
+ }
+ return false, nil
+}
+
+func findCandidate(candidates fileList) (*UpdateFound, error) {
+ var arch string
+ if runtime.GOARCH == "amd64" {
+ arch = "amd64"
+ } else if runtime.GOARCH == "386" {
+ arch = "x86"
+ } else if runtime.GOARCH == "arm64" {
+ arch = "arm64"
+ } else {
+ return nil, errors.New("Invalid GOARCH")
+ }
+ prefix := fmt.Sprintf(msiArchPrefix, arch)
+ suffix := msiSuffix
+ for name, hash := range candidates {
+ if strings.HasPrefix(name, prefix) && strings.HasSuffix(name, suffix) {
+ version := strings.TrimSuffix(strings.TrimPrefix(name, prefix), suffix)
+ if len(version) > 128 {
+ return nil, errors.New("Version length is too long")
+ }
+ newer, err := versionNewerThanUs(version)
+ if err != nil {
+ return nil, err
+ }
+ if newer {
+ return &UpdateFound{name, hash}, nil
+ }
+ }
+ }
+ return nil, nil
+}