diff options
Diffstat (limited to '')
-rw-r--r-- | updater/downloader.go | 32 |
1 files changed, 11 insertions, 21 deletions
diff --git a/updater/downloader.go b/updater/downloader.go index a12b5037..6e606532 100644 --- a/updater/downloader.go +++ b/updater/downloader.go @@ -11,11 +11,12 @@ import ( "fmt" "hash" "io" - "net/http" "os" "sync/atomic" "golang.org/x/crypto/blake2b" + + "golang.zx2c4.com/wireguard/windows/updater/winhttp" "golang.zx2c4.com/wireguard/windows/version" ) @@ -47,18 +48,13 @@ type UpdateFound struct { } func CheckForUpdate() (*UpdateFound, error) { - request, err := http.NewRequest(http.MethodGet, latestVersionURL, nil) - if err != nil { - return nil, err - } - request.Header.Add("User-Agent", version.UserAgent()) - response, err := http.DefaultClient.Do(request) + response, err := winhttp.Get(version.UserAgent(), latestVersionURL) if err != nil { return nil, err } - defer response.Body.Close() + defer response.Close() var fileList [1024 * 512] /* 512 KiB */ byte - bytesRead, err := response.Body.Read(fileList[:]) + bytesRead, err := response.Read(fileList[:]) if err != nil && (err != io.EOF || bytesRead == 0) { return nil, err } @@ -112,21 +108,15 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress dp := DownloadProgress{Activity: "Downloading update"} progress <- dp - request, err := http.NewRequest(http.MethodGet, fmt.Sprintf(msiURL, update.name), nil) - if err != nil { - progress <- DownloadProgress{Error: err} - return - } - request.Header.Add("User-Agent", version.UserAgent()) - request.Header.Set("Accept-Encoding", "identity") - response, err := http.DefaultClient.Do(request) + response, err := winhttp.Get(version.UserAgent(), fmt.Sprintf(msiURL, update.name)) if err != nil { progress <- DownloadProgress{Error: err} return } - defer response.Body.Close() - if response.ContentLength >= 0 { - dp.BytesTotal = uint64(response.ContentLength) + defer response.Close() + length, err := response.Length() + if err == nil && length >= 0 { + dp.BytesTotal = length progress <- dp } hasher, err := blake2b.New256(nil) @@ -135,7 +125,7 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress return } pm := &progressHashWatcher{&dp, progress, hasher} - _, err = io.Copy(file, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm)) + _, err = io.Copy(file, io.TeeReader(io.LimitReader(response, 1024*1024*100 /* 100 MiB */), pm)) if err != nil { progress <- DownloadProgress{Error: err} return |