aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/updater/winhttp/winhttp.go
diff options
context:
space:
mode:
Diffstat (limited to 'updater/winhttp/winhttp.go')
-rw-r--r--updater/winhttp/winhttp.go230
1 files changed, 230 insertions, 0 deletions
diff --git a/updater/winhttp/winhttp.go b/updater/winhttp/winhttp.go
new file mode 100644
index 00000000..cb19f194
--- /dev/null
+++ b/updater/winhttp/winhttp.go
@@ -0,0 +1,230 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package winhttp
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type Session struct {
+ handle _HINTERNET
+}
+
+type Connection struct {
+ handle _HINTERNET
+ session *Session
+ https bool
+}
+
+type Response struct {
+ handle _HINTERNET
+ connection *Connection
+}
+
+func convertError(err *error) {
+ if *err == nil {
+ return
+ }
+ var errno windows.Errno
+ if errors.As(*err, &errno) {
+ if errno > _WINHTTP_ERROR_BASE && errno <= _WINHTTP_ERROR_LAST {
+ *err = Error(errno)
+ }
+ }
+}
+
+func isWin7() bool {
+ maj, min, _ := windows.RtlGetNtVersionNumbers()
+ return maj < 6 || (maj == 6 && min <= 1)
+}
+
+func isWin8DotZeroOrBelow() bool {
+ maj, min, _ := windows.RtlGetNtVersionNumbers()
+ return maj < 6 || (maj == 6 && min <= 2)
+}
+
+func NewSession(userAgent string) (session *Session, err error) {
+ session = new(Session)
+ defer convertError(&err)
+ defer func() {
+ if err != nil {
+ session.Close()
+ session = nil
+ }
+ }()
+ userAgent16, err := windows.UTF16PtrFromString(userAgent)
+ if err != nil {
+ return
+ }
+ var proxyFlag uint32 = _WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY
+ if isWin7() {
+ proxyFlag = _WINHTTP_ACCESS_TYPE_DEFAULT_PROXY
+ }
+ session.handle, err = winHttpOpen(userAgent16, proxyFlag, nil, nil, 0)
+ if err != nil {
+ return
+ }
+ var enableHttp2 uint32 = _WINHTTP_PROTOCOL_FLAG_HTTP2
+ _ = winHttpSetOption(session.handle, _WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL, unsafe.Pointer(&enableHttp2), uint32(unsafe.Sizeof(enableHttp2))) // Don't check return value, in case of old Windows
+
+ if isWin8DotZeroOrBelow() {
+ var enableTLS12 uint32 = _WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2
+ err = winHttpSetOption(session.handle, _WINHTTP_OPTION_SECURE_PROTOCOLS, unsafe.Pointer(&enableTLS12), uint32(unsafe.Sizeof(enableTLS12)))
+ if err != nil {
+ return
+ }
+ }
+
+ runtime.SetFinalizer(session, func(session *Session) {
+ session.Close()
+ })
+ return
+}
+
+func (session *Session) Close() (err error) {
+ defer convertError(&err)
+ handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&session.handle), 0))
+ if handle == 0 {
+ return
+ }
+ return winHttpCloseHandle(handle)
+}
+
+func (session *Session) Connect(server string, port uint16, https bool) (connection *Connection, err error) {
+ connection = &Connection{session: session}
+ defer convertError(&err)
+ defer func() {
+ if err != nil {
+ connection.Close()
+ connection = nil
+ }
+ }()
+ server16, err := windows.UTF16PtrFromString(server)
+ if err != nil {
+ return
+ }
+ connection.handle, err = winHttpConnect(session.handle, server16, port, 0)
+ if err != nil {
+ return
+ }
+ connection.https = https
+
+ runtime.SetFinalizer(connection, func(connection *Connection) {
+ connection.Close()
+ })
+ return
+}
+
+func (connection *Connection) Close() (err error) {
+ defer convertError(&err)
+ handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&connection.handle), 0))
+ if handle == 0 {
+ return
+ }
+ return winHttpCloseHandle(handle)
+}
+
+func (connection *Connection) Get(path string, refresh bool) (response *Response, err error) {
+ response = &Response{connection: connection}
+ defer convertError(&err)
+ defer func() {
+ if err != nil {
+ response.Close()
+ response = nil
+ }
+ }()
+ var flags uint32
+ if refresh {
+ flags |= _WINHTTP_FLAG_REFRESH
+ }
+ if connection.https {
+ flags |= _WINHTTP_FLAG_SECURE
+ }
+ path16, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return
+ }
+ get16, err := windows.UTF16PtrFromString("GET")
+ if err != nil {
+ return
+ }
+ response.handle, err = winHttpOpenRequest(connection.handle, get16, path16, nil, nil, nil, flags)
+ if err != nil {
+ return
+ }
+ err = winHttpSendRequest(response.handle, nil, 0, nil, 0, 0, 0)
+ if err != nil {
+ return
+ }
+ err = winHttpReceiveResponse(response.handle, 0)
+ if err != nil {
+ return
+ }
+
+ runtime.SetFinalizer(response, func(response *Response) {
+ response.Close()
+ })
+ return
+}
+
+func (response *Response) Length() (length uint64, err error) {
+ defer convertError(&err)
+ numBuf := make([]uint16, 22)
+ numLen := uint32(len(numBuf) * 2)
+ err = winHttpQueryHeaders(response.handle, _WINHTTP_QUERY_CONTENT_LENGTH, nil, unsafe.Pointer(&numBuf[0]), &numLen, nil)
+ if err != nil {
+ return
+ }
+ length, err = strconv.ParseUint(windows.UTF16ToString(numBuf[:numLen]), 10, 64)
+ if err != nil {
+ return
+ }
+ return
+}
+
+func (response *Response) Read(p []byte) (n int, err error) {
+ defer convertError(&err)
+ if len(p) == 0 {
+ return 0, nil
+ }
+ var bytesRead uint32
+ err = winHttpReadData(response.handle, &p[0], uint32(len(p)), &bytesRead)
+ if err != nil {
+ return 0, nil
+ }
+ if bytesRead == 0 || int(bytesRead) < 0 {
+ return 0, io.EOF
+ }
+ return int(bytesRead), nil
+}
+
+func (response *Response) Close() (err error) {
+ defer convertError(&err)
+ handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&response.handle), 0))
+ if handle == 0 {
+ return
+ }
+ return winHttpCloseHandle(handle)
+}
+
+func (error Error) Error() string {
+ var message [2048]uint16
+ n, err := windows.FormatMessage(windows.FORMAT_MESSAGE_FROM_HMODULE|windows.FORMAT_MESSAGE_IGNORE_INSERTS|windows.FORMAT_MESSAGE_MAX_WIDTH_MASK,
+ modwinhttp.Handle(), uint32(error), 0, message[:], nil)
+ if err != nil {
+ return fmt.Sprintf("WinHTTP error #%d", error)
+ }
+ return strings.TrimSpace(windows.UTF16ToString(message[:n]))
+}