aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/updater/downloader.go
diff options
context:
space:
mode:
Diffstat (limited to 'updater/downloader.go')
-rw-r--r--updater/downloader.go94
1 files changed, 51 insertions, 43 deletions
diff --git a/updater/downloader.go b/updater/downloader.go
index fad9534e..bf28db54 100644
--- a/updater/downloader.go
+++ b/updater/downloader.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package updater
@@ -11,13 +11,12 @@ import (
"fmt"
"hash"
"io"
- "net/http"
- "os"
"sync/atomic"
"golang.org/x/crypto/blake2b"
"golang.zx2c4.com/wireguard/windows/elevate"
+ "golang.zx2c4.com/wireguard/windows/updater/winhttp"
"golang.zx2c4.com/wireguard/windows/version"
)
@@ -48,30 +47,55 @@ type UpdateFound struct {
hash [blake2b.Size256]byte
}
-func CheckForUpdate() (*UpdateFound, error) {
+func CheckForUpdate() (updateFound *UpdateFound, err error) {
+ updateFound, _, _, err = checkForUpdate(false)
+ return
+}
+
+func checkForUpdate(keepSession bool) (*UpdateFound, *winhttp.Session, *winhttp.Connection, error) {
if !version.IsRunningOfficialVersion() {
- return nil, errors.New("Build is not official, so updates are disabled")
+ return nil, nil, nil, errors.New("Build is not official, so updates are disabled")
}
- request, err := http.NewRequest(http.MethodGet, latestVersionURL, nil)
+ session, err := winhttp.NewSession(version.UserAgent())
if err != nil {
- return nil, err
+ return nil, nil, nil, err
}
- request.Header.Add("User-Agent", version.UserAgent())
- response, err := http.DefaultClient.Do(request)
+ defer func() {
+ if err != nil || !keepSession {
+ session.Close()
+ }
+ }()
+ connection, err := session.Connect(updateServerHost, updateServerPort, updateServerUseHttps)
if err != nil {
- return nil, err
+ return nil, nil, nil, err
}
- defer response.Body.Close()
+ defer func() {
+ if err != nil || !keepSession {
+ connection.Close()
+ }
+ }()
+ response, err := connection.Get(latestVersionPath, true)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ defer response.Close()
var fileList [1024 * 512] /* 512 KiB */ byte
- bytesRead, err := response.Body.Read(fileList[:])
+ bytesRead, err := response.Read(fileList[:])
if err != nil && (err != io.EOF || bytesRead == 0) {
- return nil, err
+ return nil, nil, nil, err
}
files, err := readFileList(fileList[:bytesRead])
if err != nil {
- return nil, err
+ return nil, nil, nil, err
+ }
+ updateFound, err := findCandidate(files)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ if keepSession {
+ return updateFound, session, connection, nil
}
- return findCandidate(files)
+ return updateFound, nil, nil, nil
}
var updateInProgress = uint32(0)
@@ -89,11 +113,13 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
defer atomic.StoreUint32(&updateInProgress, 0)
progress <- DownloadProgress{Activity: "Checking for update"}
- update, err := CheckForUpdate()
+ update, session, connection, err := checkForUpdate(true)
if err != nil {
progress <- DownloadProgress{Error: err}
return
}
+ defer connection.Close()
+ defer session.Close()
if update == nil {
progress <- DownloadProgress{Error: errors.New("No update was found")}
return
@@ -108,31 +134,21 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
progress <- DownloadProgress{Activity: fmt.Sprintf("Msi destination is %#q", file.Name())}
defer func() {
if file != nil {
- name := file.Name()
- file.Seek(0, io.SeekStart)
- file.Truncate(0)
- file.Close()
- os.Remove(name) // TODO: Do we have any sort of TOCTOU here?
+ file.Delete()
}
}()
dp := DownloadProgress{Activity: "Downloading update"}
progress <- dp
- request, err := http.NewRequest(http.MethodGet, fmt.Sprintf(msiURL, update.name), nil)
+ response, err := connection.Get(fmt.Sprintf(msiPath, update.name), false)
if err != nil {
progress <- DownloadProgress{Error: err}
return
}
- request.Header.Add("User-Agent", version.UserAgent())
- request.Header.Set("Accept-Encoding", "identity")
- response, err := http.DefaultClient.Do(request)
- if err != nil {
- progress <- DownloadProgress{Error: err}
- return
- }
- defer response.Body.Close()
- if response.ContentLength >= 0 {
- dp.BytesTotal = uint64(response.ContentLength)
+ defer response.Close()
+ length, err := response.Length()
+ if err == nil && length >= 0 {
+ dp.BytesTotal = length
progress <- dp
}
hasher, err := blake2b.New256(nil)
@@ -141,7 +157,7 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
return
}
pm := &progressHashWatcher{&dp, progress, hasher}
- _, err = io.Copy(file, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm))
+ _, err = io.Copy(file, io.TeeReader(io.LimitReader(response, 1024*1024*100 /* 100 MiB */), pm))
if err != nil {
progress <- DownloadProgress{Error: err}
return
@@ -151,22 +167,14 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
return
}
- // TODO: it would be nice to rename in place from "file.msi.unverified" to "file.msi", but Windows TOCTOU stuff
- // is hard, so we'll come back to this later.
- name := file.Name()
- file.Close()
- file = nil
-
progress <- DownloadProgress{Activity: "Verifying authenticode signature"}
- if !version.VerifyAuthenticode(name) {
- os.Remove(name) // TODO: Do we have any sort of TOCTOU here?
+ if !verifyAuthenticode(file.ExclusivePath()) {
progress <- DownloadProgress{Error: errors.New("The downloaded update does not have an authentic authenticode signature")}
return
}
progress <- DownloadProgress{Activity: "Installing update"}
- err = runMsi(name, userToken)
- os.Remove(name) // TODO: Do we have any sort of TOCTOU here?
+ err = runMsi(file, userToken)
if err != nil {
progress <- DownloadProgress{Error: err}
return