aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-09-23 14:54:20 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-09-23 15:29:18 +0200
commitf2c1f2a478a36cf94d54c6eee4c9addc34f10457 (patch)
tree29a07822f3ffea086b03a5bcb0d0c9a6954b7cff
parentupdater: use winhttp to reduce filesize (diff)
downloadwireguard-windows-f2c1f2a478a36cf94d54c6eee4c9addc34f10457.tar.xz
wireguard-windows-f2c1f2a478a36cf94d54c6eee4c9addc34f10457.zip
version: use crypt32 instead of go x509 for cn extraction for file size
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--version/official_windows.go6
-rw-r--r--version/wintrust/certificate_test.go28
-rw-r--r--version/wintrust/certificate_windows.go39
-rw-r--r--version/wintrust/zsyscall_windows.go13
4 files changed, 65 insertions, 21 deletions
diff --git a/version/official_windows.go b/version/official_windows.go
index b0f62250..d9f041f6 100644
--- a/version/official_windows.go
+++ b/version/official_windows.go
@@ -67,12 +67,12 @@ func IsRunningOfficialVersion() bool {
// This below tests is easily circumvented. False certificates can be appended, and just checking the
// common name is not very good. But that's okay, as this isn't security related.
- certs, err := wintrust.ExtractCertificates(path)
+ names, err := wintrust.ExtractCertificateNames(path)
if err != nil {
return false
}
- for _, cert := range certs {
- if cert.Subject.CommonName == officialCommonName {
+ for _, name := range names {
+ if name == officialCommonName {
return true
}
}
diff --git a/version/wintrust/certificate_test.go b/version/wintrust/certificate_test.go
new file mode 100644
index 00000000..19007351
--- /dev/null
+++ b/version/wintrust/certificate_test.go
@@ -0,0 +1,28 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wintrust
+
+import (
+ "fmt"
+ "path/filepath"
+ "testing"
+
+ "golang.org/x/sys/windows"
+)
+
+func TestExtractCertificateNames(t *testing.T) {
+ system32, err := windows.GetSystemDirectory()
+ if err != nil {
+ t.Fatal(err)
+ }
+ names, err := ExtractCertificateNames(filepath.Join(system32, "ntoskrnl.exe"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i, name := range names {
+ fmt.Printf("%d: %s\n", i, name)
+ }
+}
diff --git a/version/wintrust/certificate_windows.go b/version/wintrust/certificate_windows.go
index cf254e4f..8c933f11 100644
--- a/version/wintrust/certificate_windows.go
+++ b/version/wintrust/certificate_windows.go
@@ -6,7 +6,6 @@
package wintrust
import (
- "crypto/x509"
"syscall"
"unsafe"
@@ -14,26 +13,28 @@ import (
)
const (
- CERT_QUERY_OBJECT_FILE = 1
- CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED = 1024
- CERT_QUERY_FORMAT_FLAG_ALL = 14
+ _CERT_QUERY_OBJECT_FILE = 1
+ _CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED = 1024
+ _CERT_QUERY_FORMAT_FLAG_ALL = 14
+ _CERT_NAME_SIMPLE_DISPLAY_TYPE = 4
)
-//sys CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) = crypt32.CryptQueryObject
+//sys cryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) = crypt32.CryptQueryObject
+//sys certGetNameString(certContext *windows.CertContext, nameType uint32, flags uint32, typePara uintptr, name *uint16, size uint32) (chars uint32) = crypt32.CertGetNameStringW
-func ExtractCertificates(path string) ([]x509.Certificate, error) {
+func ExtractCertificateNames(path string) ([]string, error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, err
}
var certStore windows.Handle
- err = CryptQueryObject(CERT_QUERY_OBJECT_FILE, uintptr(unsafe.Pointer(path16)), CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil)
+ err = cryptQueryObject(_CERT_QUERY_OBJECT_FILE, uintptr(unsafe.Pointer(path16)), _CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, _CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil)
if err != nil {
return nil, err
}
defer windows.CertCloseStore(certStore, 0)
- var certs []x509.Certificate
var cert *windows.CertContext
+ var names []string
for {
cert, err = windows.CertEnumCertificatesInStore(certStore, cert)
if err != nil {
@@ -47,13 +48,21 @@ func ExtractCertificates(path string) ([]x509.Certificate, error) {
if cert == nil {
break
}
- buf := make([]byte, cert.Length)
- copy(buf, (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:])
- if c, err := x509.ParseCertificate(buf); err == nil {
- certs = append(certs, *c)
- } else {
- return nil, err
+ nameLen := certGetNameString(cert, _CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, 0, nil, 0)
+ if nameLen == 0 {
+ continue
+ }
+ name16 := make([]uint16, nameLen)
+ if certGetNameString(cert, _CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, 0, &name16[0], nameLen) != nameLen {
+ continue
}
+ if name16[0] == 0 {
+ continue
+ }
+ names = append(names, windows.UTF16ToString(name16))
+ }
+ if names == nil {
+ return nil, syscall.Errno(windows.CRYPT_E_NOT_FOUND)
}
- return certs, nil
+ return names, nil
}
diff --git a/version/wintrust/zsyscall_windows.go b/version/wintrust/zsyscall_windows.go
index 8aa315c0..7c742938 100644
--- a/version/wintrust/zsyscall_windows.go
+++ b/version/wintrust/zsyscall_windows.go
@@ -40,8 +40,9 @@ var (
modwintrust = windows.NewLazySystemDLL("wintrust.dll")
modcrypt32 = windows.NewLazySystemDLL("crypt32.dll")
- procWinVerifyTrust = modwintrust.NewProc("WinVerifyTrust")
- procCryptQueryObject = modcrypt32.NewProc("CryptQueryObject")
+ procWinVerifyTrust = modwintrust.NewProc("WinVerifyTrust")
+ procCryptQueryObject = modcrypt32.NewProc("CryptQueryObject")
+ procCertGetNameStringW = modcrypt32.NewProc("CertGetNameStringW")
)
func WinVerifyTrust(hWnd windows.Handle, actionId *windows.GUID, data *WinTrustData) (err error) {
@@ -56,7 +57,7 @@ func WinVerifyTrust(hWnd windows.Handle, actionId *windows.GUID, data *WinTrustD
return
}
-func CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) {
+func cryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) {
r1, _, e1 := syscall.Syscall12(procCryptQueryObject.Addr(), 11, uintptr(objectType), uintptr(object), uintptr(expectedContentTypeFlags), uintptr(expectedFormatTypeFlags), uintptr(flags), uintptr(unsafe.Pointer(msgAndCertEncodingType)), uintptr(unsafe.Pointer(contentType)), uintptr(unsafe.Pointer(formatType)), uintptr(unsafe.Pointer(certStore)), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(context)), 0)
if r1 == 0 {
if e1 != 0 {
@@ -67,3 +68,9 @@ func CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlag
}
return
}
+
+func certGetNameString(certContext *windows.CertContext, nameType uint32, flags uint32, typePara uintptr, name *uint16, size uint32) (chars uint32) {
+ r0, _, _ := syscall.Syscall6(procCertGetNameStringW.Addr(), 6, uintptr(unsafe.Pointer(certContext)), uintptr(nameType), uintptr(flags), uintptr(typePara), uintptr(unsafe.Pointer(name)), uintptr(size))
+ chars = uint32(r0)
+ return
+}