diff options
-rw-r--r-- | version/official_windows.go | 6 | ||||
-rw-r--r-- | version/wintrust/certificate_test.go | 28 | ||||
-rw-r--r-- | version/wintrust/certificate_windows.go | 39 | ||||
-rw-r--r-- | version/wintrust/zsyscall_windows.go | 13 |
4 files changed, 65 insertions, 21 deletions
diff --git a/version/official_windows.go b/version/official_windows.go index b0f62250..d9f041f6 100644 --- a/version/official_windows.go +++ b/version/official_windows.go @@ -67,12 +67,12 @@ func IsRunningOfficialVersion() bool { // This below tests is easily circumvented. False certificates can be appended, and just checking the // common name is not very good. But that's okay, as this isn't security related. - certs, err := wintrust.ExtractCertificates(path) + names, err := wintrust.ExtractCertificateNames(path) if err != nil { return false } - for _, cert := range certs { - if cert.Subject.CommonName == officialCommonName { + for _, name := range names { + if name == officialCommonName { return true } } diff --git a/version/wintrust/certificate_test.go b/version/wintrust/certificate_test.go new file mode 100644 index 00000000..19007351 --- /dev/null +++ b/version/wintrust/certificate_test.go @@ -0,0 +1,28 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wintrust + +import ( + "fmt" + "path/filepath" + "testing" + + "golang.org/x/sys/windows" +) + +func TestExtractCertificateNames(t *testing.T) { + system32, err := windows.GetSystemDirectory() + if err != nil { + t.Fatal(err) + } + names, err := ExtractCertificateNames(filepath.Join(system32, "ntoskrnl.exe")) + if err != nil { + t.Fatal(err) + } + for i, name := range names { + fmt.Printf("%d: %s\n", i, name) + } +} diff --git a/version/wintrust/certificate_windows.go b/version/wintrust/certificate_windows.go index cf254e4f..8c933f11 100644 --- a/version/wintrust/certificate_windows.go +++ b/version/wintrust/certificate_windows.go @@ -6,7 +6,6 @@ package wintrust import ( - "crypto/x509" "syscall" "unsafe" @@ -14,26 +13,28 @@ import ( ) const ( - CERT_QUERY_OBJECT_FILE = 1 - CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED = 1024 - CERT_QUERY_FORMAT_FLAG_ALL = 14 + _CERT_QUERY_OBJECT_FILE = 1 + _CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED = 1024 + _CERT_QUERY_FORMAT_FLAG_ALL = 14 + _CERT_NAME_SIMPLE_DISPLAY_TYPE = 4 ) -//sys CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) = crypt32.CryptQueryObject +//sys cryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) = crypt32.CryptQueryObject +//sys certGetNameString(certContext *windows.CertContext, nameType uint32, flags uint32, typePara uintptr, name *uint16, size uint32) (chars uint32) = crypt32.CertGetNameStringW -func ExtractCertificates(path string) ([]x509.Certificate, error) { +func ExtractCertificateNames(path string) ([]string, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return nil, err } var certStore windows.Handle - err = CryptQueryObject(CERT_QUERY_OBJECT_FILE, uintptr(unsafe.Pointer(path16)), CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil) + err = cryptQueryObject(_CERT_QUERY_OBJECT_FILE, uintptr(unsafe.Pointer(path16)), _CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, _CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil) if err != nil { return nil, err } defer windows.CertCloseStore(certStore, 0) - var certs []x509.Certificate var cert *windows.CertContext + var names []string for { cert, err = windows.CertEnumCertificatesInStore(certStore, cert) if err != nil { @@ -47,13 +48,21 @@ func ExtractCertificates(path string) ([]x509.Certificate, error) { if cert == nil { break } - buf := make([]byte, cert.Length) - copy(buf, (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]) - if c, err := x509.ParseCertificate(buf); err == nil { - certs = append(certs, *c) - } else { - return nil, err + nameLen := certGetNameString(cert, _CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, 0, nil, 0) + if nameLen == 0 { + continue + } + name16 := make([]uint16, nameLen) + if certGetNameString(cert, _CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, 0, &name16[0], nameLen) != nameLen { + continue } + if name16[0] == 0 { + continue + } + names = append(names, windows.UTF16ToString(name16)) + } + if names == nil { + return nil, syscall.Errno(windows.CRYPT_E_NOT_FOUND) } - return certs, nil + return names, nil } diff --git a/version/wintrust/zsyscall_windows.go b/version/wintrust/zsyscall_windows.go index 8aa315c0..7c742938 100644 --- a/version/wintrust/zsyscall_windows.go +++ b/version/wintrust/zsyscall_windows.go @@ -40,8 +40,9 @@ var ( modwintrust = windows.NewLazySystemDLL("wintrust.dll") modcrypt32 = windows.NewLazySystemDLL("crypt32.dll") - procWinVerifyTrust = modwintrust.NewProc("WinVerifyTrust") - procCryptQueryObject = modcrypt32.NewProc("CryptQueryObject") + procWinVerifyTrust = modwintrust.NewProc("WinVerifyTrust") + procCryptQueryObject = modcrypt32.NewProc("CryptQueryObject") + procCertGetNameStringW = modcrypt32.NewProc("CertGetNameStringW") ) func WinVerifyTrust(hWnd windows.Handle, actionId *windows.GUID, data *WinTrustData) (err error) { @@ -56,7 +57,7 @@ func WinVerifyTrust(hWnd windows.Handle, actionId *windows.GUID, data *WinTrustD return } -func CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) { +func cryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlags uint32, expectedFormatTypeFlags uint32, flags uint32, msgAndCertEncodingType *uint32, contentType *uint32, formatType *uint32, certStore *windows.Handle, msg *windows.Handle, context *uintptr) (err error) { r1, _, e1 := syscall.Syscall12(procCryptQueryObject.Addr(), 11, uintptr(objectType), uintptr(object), uintptr(expectedContentTypeFlags), uintptr(expectedFormatTypeFlags), uintptr(flags), uintptr(unsafe.Pointer(msgAndCertEncodingType)), uintptr(unsafe.Pointer(contentType)), uintptr(unsafe.Pointer(formatType)), uintptr(unsafe.Pointer(certStore)), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(context)), 0) if r1 == 0 { if e1 != 0 { @@ -67,3 +68,9 @@ func CryptQueryObject(objectType uint32, object uintptr, expectedContentTypeFlag } return } + +func certGetNameString(certContext *windows.CertContext, nameType uint32, flags uint32, typePara uintptr, name *uint16, size uint32) (chars uint32) { + r0, _, _ := syscall.Syscall6(procCertGetNameStringW.Addr(), 6, uintptr(unsafe.Pointer(certContext)), uintptr(nameType), uintptr(flags), uintptr(typePara), uintptr(unsafe.Pointer(name)), uintptr(size)) + chars = uint32(r0) + return +} |