aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service
diff options
context:
space:
mode:
Diffstat (limited to 'service')
-rw-r--r--service/errors.go3
-rw-r--r--service/ipc_client.go78
-rw-r--r--service/ipc_server.go58
-rw-r--r--service/securityapi.go157
-rw-r--r--service/service_manager.go182
-rw-r--r--service/updatestate.go56
6 files changed, 386 insertions, 148 deletions
diff --git a/service/errors.go b/service/errors.go
index ecc0283a..eb31bb4e 100644
--- a/service/errors.go
+++ b/service/errors.go
@@ -26,7 +26,6 @@ const (
ErrorBindSocketsToDefaultRoutes
ErrorSetNetConfig
ErrorDetermineExecutablePath
- ErrorFindAdministratorsSID
ErrorCreateSecurityDescriptor
ErrorOpenNULFile
ErrorTrackTunnels
@@ -60,8 +59,6 @@ func (e Error) Error() string {
return "Unable to bind sockets to default route"
case ErrorSetNetConfig:
return "Unable to set interface addresses, routes, dns, and/or adapter settings"
- case ErrorFindAdministratorsSID:
- return "Unable to find Administrators SID"
case ErrorCreateSecurityDescriptor:
return "Unable to determine security descriptor"
case ErrorOpenNULFile:
diff --git a/service/ipc_client.go b/service/ipc_client.go
index 41e71f22..adfd456c 100644
--- a/service/ipc_client.go
+++ b/service/ipc_client.go
@@ -9,6 +9,7 @@ import (
"encoding/gob"
"errors"
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
"net/rpc"
"os"
)
@@ -33,6 +34,8 @@ const (
TunnelChangeNotificationType NotificationType = iota
TunnelsChangeNotificationType
ManagerStoppingNotificationType
+ UpdateFoundNotificationType
+ UpdateProgressNotificationType
)
var rpcClient *rpc.Client
@@ -55,6 +58,18 @@ type ManagerStoppingCallback struct {
var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool)
+type UpdateFoundCallback struct {
+ cb func(updateState UpdateState)
+}
+
+var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool)
+
+type UpdateProgressCallback struct {
+ cb func(dp updater.DownloadProgress)
+}
+
+var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
+
func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
rpcClient = rpc.NewClient(&pipeRWC{reader, writer})
go func() {
@@ -106,6 +121,44 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
for cb := range managerStoppingCallbacks {
cb.cb()
}
+ case UpdateFoundNotificationType:
+ var state UpdateState
+ err = decoder.Decode(&state)
+ if err != nil {
+ continue
+ }
+ for cb := range updateFoundCallbacks {
+ cb.cb(state)
+ }
+ case UpdateProgressNotificationType:
+ var dp updater.DownloadProgress
+ err = decoder.Decode(&dp.Activity)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesDownloaded)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesTotal)
+ if err != nil {
+ continue
+ }
+ var errStr string
+ err = decoder.Decode(&errStr)
+ if err != nil {
+ continue
+ }
+ if len(errStr) > 0 {
+ dp.Error = errors.New(errStr)
+ }
+ err = decoder.Decode(&dp.Complete)
+ if err != nil {
+ continue
+ }
+ for cb := range updateProgressCallbacks {
+ cb.cb(dp)
+ }
}
}
}()
@@ -176,6 +229,15 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
}
+func IPCClientUpdateState() (UpdateState, error) {
+ var state UpdateState
+ return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state)
+}
+
+func IPCClientUpdate() error {
+ return rpcClient.Call("ManagerService.Update", uintptr(0), nil)
+}
+
func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback {
s := &TunnelChangeCallback{cb}
tunnelChangeCallbacks[s] = true
@@ -200,3 +262,19 @@ func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
func (cb *ManagerStoppingCallback) Unregister() {
delete(managerStoppingCallbacks, cb)
}
+func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
+ s := &UpdateFoundCallback{cb}
+ updateFoundCallbacks[s] = true
+ return s
+}
+func (cb *UpdateFoundCallback) Unregister() {
+ delete(updateFoundCallbacks, cb)
+}
+func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
+ s := &UpdateProgressCallback{cb}
+ updateProgressCallbacks[s] = true
+ return s
+}
+func (cb *UpdateProgressCallback) Unregister() {
+ delete(updateProgressCallbacks, cb)
+}
diff --git a/service/ipc_server.go b/service/ipc_server.go
index 9ccea8ef..becad3ee 100644
--- a/service/ipc_server.go
+++ b/service/ipc_server.go
@@ -10,8 +10,10 @@ import (
"encoding/gob"
"fmt"
"github.com/Microsoft/go-winio"
+ "golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
"io/ioutil"
"log"
"net/rpc"
@@ -27,8 +29,14 @@ var managerServicesLock sync.RWMutex
var haveQuit uint32
var quitManagersChan = make(chan struct{}, 1)
+type UserTokenInfo struct {
+ elevatedToken windows.Token
+ elevatedEnvironment []string
+}
+
type ManagerService struct {
- events *os.File
+ events *os.File
+ userTokenInfo *UserTokenInfo
}
func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error {
@@ -115,7 +123,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
return InstallTunnel(path)
}
-func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error {
err := UninstallTunnel(tunnelName)
if err == syscall.Errno(serviceDOES_NOT_EXIST) {
_, notExistsError := conf.LoadFromName(tunnelName)
@@ -126,7 +134,7 @@ func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error {
return err
}
-func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error {
serviceName, err := ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
@@ -146,7 +154,7 @@ func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error {
}
}
-func (s *ManagerService) Delete(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error {
err := s.Stop(tunnelName, nil)
if err != nil {
return err
@@ -189,7 +197,7 @@ func (s *ManagerService) State(tunnelName string, state *TunnelState) error {
return nil
}
-func (s *ManagerService) GlobalState(unused uintptr, state *TunnelState) error {
+func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error {
*state = trackedTunnelsGlobalState()
return nil
}
@@ -205,7 +213,7 @@ func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error
//TODO: handle already running and existing situation
}
-func (s *ManagerService) Tunnels(unused uintptr, tunnels *[]Tunnel) error {
+func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error {
names, err := conf.ListConfigNames()
if err != nil {
return err
@@ -244,8 +252,30 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
return nil
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error {
- service := &ManagerService{events: events}
+func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error {
+ *state = updateState
+ return nil
+}
+
+func (s *ManagerService) Update(_ uintptr, _ *uintptr) error {
+ progress := updater.DownloadVerifyAndExecute(uintptr(s.userTokenInfo.elevatedToken), s.userTokenInfo.elevatedEnvironment)
+ go func() {
+ for {
+ dp := <-progress
+ IPCServerNotifyUpdateProgress(dp)
+ if dp.Complete || dp.Error != nil {
+ return
+ }
+ }
+ }()
+ return nil
+}
+
+func IPCServerListen(reader *os.File, writer *os.File, events *os.File, userTokenInfo *UserTokenInfo) error {
+ service := &ManagerService{
+ events: events,
+ userTokenInfo: userTokenInfo,
+ }
server := rpc.NewServer()
err := server.Register(service)
@@ -304,6 +334,18 @@ func IPCServerNotifyTunnelsChange() {
notifyAll(TunnelsChangeNotificationType)
}
+func IPCServerNotifyUpdateFound(state UpdateState) {
+ notifyAll(UpdateFoundNotificationType, state)
+}
+
+func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
+ if dp.Error == nil {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete)
+ } else {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete)
+ }
+}
+
func IPCServerNotifyManagerStopping() {
notifyAll(ManagerStoppingNotificationType)
time.Sleep(time.Millisecond * 200)
diff --git a/service/securityapi.go b/service/securityapi.go
new file mode 100644
index 00000000..6c5f7844
--- /dev/null
+++ b/service/securityapi.go
@@ -0,0 +1,157 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package service
+
+import (
+ "errors"
+ "golang.org/x/sys/windows"
+ "syscall"
+ "unicode/utf16"
+ "unsafe"
+)
+
+const (
+ wtsSessionLogon uint32 = 5
+ wtsSessionLogoff uint32 = 6
+)
+
+type wtsState int
+
+const (
+ wtsActive wtsState = iota
+ wtsConnected
+ wtsConnectQuery
+ wtsShadow
+ wtsDisconnected
+ wtsIdle
+ wtsListen
+ wtsReset
+ wtsDown
+ wtsInit
+)
+
+type wtsSessionNotification struct {
+ size uint32
+ sessionID uint32
+}
+
+type wtsSessionInfo struct {
+ sessionID uint32
+ windowStationName *uint16
+ state wtsState
+}
+
+//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken
+//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW
+//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory
+
+const (
+ SE_KERNEL_OBJECT = 6
+
+ DACL_SECURITY_INFORMATION = 4
+ ATTRIBUTE_SECURITY_INFORMATION = 16
+)
+
+//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo
+//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+
+//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock
+//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock
+
+func userEnviron(token windows.Token) (env []string, err error) {
+ var block uintptr
+ err = createEnvironmentBlock(&block, token, false)
+ if err != nil {
+ return
+ }
+ offset := uintptr(0)
+ for {
+ entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:]
+ for i, v := range entry {
+ if v == 0 {
+ entry = entry[:i]
+ break
+ }
+ }
+ if len(entry) == 0 {
+ break
+ }
+ env = append(env, string(utf16.Decode(entry)))
+ offset += 2 * (uintptr(len(entry)) + 1)
+ }
+ destroyEnvironmentBlock(block)
+ return
+}
+
+func tokenIsElevated(token windows.Token) bool {
+ var isElevated uint32
+ var outLen uint32
+ err := windows.GetTokenInformation(token, windows.TokenElevation, (*byte)(unsafe.Pointer(&isElevated)), uint32(unsafe.Sizeof(isElevated)), &outLen)
+ if err != nil {
+ return false
+ }
+ return outLen == uint32(unsafe.Sizeof(isElevated)) && isElevated != 0
+
+}
+
+func getElevatedToken(token windows.Token) (windows.Token, error) {
+ if tokenIsElevated(token) {
+ return token, nil
+ }
+ var linkedToken windows.Token
+ var outLen uint32
+ err := windows.GetTokenInformation(token, windows.TokenLinkedToken, (*byte)(unsafe.Pointer(&linkedToken)), uint32(unsafe.Sizeof(linkedToken)), &outLen)
+ if err != nil {
+ return windows.Token(0), err
+ }
+ if tokenIsElevated(linkedToken) {
+ return linkedToken, nil
+ }
+ linkedToken.Close()
+ return windows.Token(0), errors.New("the linked token is not elevated")
+}
+
+func tokenIsMemberOfBuiltInAdministrator(token windows.Token) bool {
+ //TODO: SECURITY CRITICIAL!
+ //TODO: Isn't it better to use an impersonation token or userToken.IsMember instead?
+ adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
+ if err != nil {
+ return false
+ }
+ gs, err := token.GetTokenGroups()
+ if err != nil {
+ return false
+ }
+ p := unsafe.Pointer(&gs.Groups[0])
+ groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount]
+ isAdmin := false
+ for _, g := range groups {
+ if windows.EqualSid(g.Sid, adminSid) {
+ isAdmin = true
+ break
+ }
+ }
+ return isAdmin
+}
+
+func getCurrentSecurityAttributes() (*syscall.SecurityAttributes, error) {
+ currentProcess, err := windows.GetCurrentProcess()
+ if err != nil {
+ return nil, err
+ }
+ securityAttributes := &syscall.SecurityAttributes{}
+ err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor)
+ if err != nil {
+ return nil, err
+ }
+ windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
+ securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor)
+ if securityAttributes.Length == 0 {
+ windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
+ return nil, err
+ }
+ return securityAttributes, nil
+} \ No newline at end of file
diff --git a/service/service_manager.go b/service/service_manager.go
index 715b4257..e8818ae4 100644
--- a/service/service_manager.go
+++ b/service/service_manager.go
@@ -16,82 +16,9 @@ import (
"runtime/debug"
"sync"
"syscall"
- "unicode/utf16"
"unsafe"
)
-const (
- wtsSessionLogon uint32 = 5
- wtsSessionLogoff uint32 = 6
-)
-
-type wtsState int
-
-const (
- wtsActive wtsState = iota
- wtsConnected
- wtsConnectQuery
- wtsShadow
- wtsDisconnected
- wtsIdle
- wtsListen
- wtsReset
- wtsDown
- wtsInit
-)
-
-type wtsSessionNotification struct {
- size uint32
- sessionID uint32
-}
-
-type wtsSessionInfo struct {
- sessionID uint32
- windowStationName *uint16
- state wtsState
-}
-
-//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken
-//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW
-//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory
-
-const (
- SE_KERNEL_OBJECT = 6
- DACL_SECURITY_INFORMATION = 4
- ATTRIBUTE_SECURITY_INFORMATION = 16
-)
-
-//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo
-//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
-
-//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock
-//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock
-
-func userEnviron(token windows.Token) (env []string, err error) {
- var block uintptr
- err = createEnvironmentBlock(&block, token, false)
- if err != nil {
- return
- }
- offset := uintptr(0)
- for {
- entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:]
- for i, v := range entry {
- if v == 0 {
- entry = entry[:i]
- break
- }
- }
- if len(entry) == 0 {
- break
- }
- env = append(env, string(utf16.Decode(entry)))
- offset += 2 * (uintptr(len(entry)) + 1)
- }
- destroyEnvironmentBlock(block)
- return
-}
-
type managerService struct{}
func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
@@ -126,29 +53,12 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
serviceError = ErrorDetermineExecutablePath
return
}
-
- adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
- if err != nil {
- serviceError = ErrorFindAdministratorsSID
- return
- }
-
- currentProcess, err := windows.GetCurrentProcess()
- if err != nil {
- panic(err)
- }
- var securityAttributes syscall.SecurityAttributes
- err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor)
+ securityAttributes, err := getCurrentSecurityAttributes()
if err != nil {
serviceError = ErrorCreateSecurityDescriptor
return
}
defer windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
- securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor)
- if securityAttributes.Length == 0 {
- serviceError = ErrorCreateSecurityDescriptor
- return
- }
devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
if err != nil {
@@ -172,49 +82,51 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
startProcess = func(session uint32) {
defer runtime.UnlockOSThread()
+
+ var userToken windows.Token
+ err := wtsQueryUserToken(session, &userToken)
+ if err != nil {
+ return
+ }
+ defer userToken.Close()
+ if !tokenIsMemberOfBuiltInAdministrator(userToken) {
+ return
+ }
+ user, err := userToken.GetTokenUser()
+ if err != nil {
+ log.Printf("Unable to lookup user from token: %v", err)
+ return
+ }
+ username, domain, accType, err := user.User.Sid.LookupAccount("")
+ if err != nil {
+ log.Printf("Unable to lookup username from sid: %v", err)
+ return
+ }
+ if accType != windows.SidTypeUser {
+ return
+ }
+ env, err := userEnviron(userToken)
+ if err != nil {
+ log.Printf("Unable to determine user environment: %v", err)
+ return
+ }
+ userTokenInfo := &UserTokenInfo{}
+ userTokenInfo.elevatedToken, err = getElevatedToken(userToken)
+ if err != nil {
+ log.Printf("Unable to elevate token: %v", err)
+ }
+ if userTokenInfo.elevatedToken != userToken {
+ defer userTokenInfo.elevatedToken.Close()
+ }
+ userTokenInfo.elevatedEnvironment, err = userEnviron(userTokenInfo.elevatedToken)
+ if err != nil {
+ log.Printf("Unable to determine elevated environment: %v", err)
+ return
+ }
for {
if stoppingManager {
return
}
- var userToken windows.Token
- err := wtsQueryUserToken(session, &userToken)
- if err != nil {
- return
- }
-
- //TODO: SECURITY CRITICIAL!
- //TODO: Isn't it better to use an impersonation token and userToken.IsMember instead?
- gs, err := userToken.GetTokenGroups()
- if err != nil {
- log.Printf("Unable to lookup user groups from token: %v", err)
- return
- }
- p := unsafe.Pointer(&gs.Groups[0])
- groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount]
- isAdmin := false
- for _, g := range groups {
- if windows.EqualSid(g.Sid, adminSid) {
- isAdmin = true
- break
- }
- }
- if !isAdmin {
- return
- }
-
- user, err := userToken.GetTokenUser()
- if err != nil {
- log.Printf("Unable to lookup user from token: %v", err)
- return
- }
- username, domain, accType, err := user.User.Sid.LookupAccount("")
- if err != nil {
- log.Printf("Unable to lookup username from sid: %v", err)
- return
- }
- if accType != windows.SidTypeUser {
- return
- }
//TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that
// might be running in parallel Go routines. But the Go runtime is strange and who knows what's really
@@ -226,7 +138,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
ourEvents, theirEvents, theirEventStr, err := inheritableEvents()
- err = IPCServerListen(ourReader, ourWriter, ourEvents)
+ err = IPCServerListen(ourReader, ourWriter, ourEvents, userTokenInfo)
if err != nil {
log.Printf("Unable to listen on IPC pipes: %v", err)
return
@@ -237,12 +149,6 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
- env, err := userEnviron(userToken)
- if err != nil {
- log.Printf("Unable to determine user environment: %v", err)
- return
- }
-
log.Printf("Starting UI process for user: '%s@%s'", username, domain)
attr := &os.ProcAttr{
Sys: &syscall.SysProcAttr{
@@ -286,6 +192,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
}
}
+ go checkForUpdates()
+
var sessionsPointer *wtsSessionInfo
var count uint32
err = wtsEnumerateSessions(0, 0, 1, &sessionsPointer, &count)
diff --git a/service/updatestate.go b/service/updatestate.go
new file mode 100644
index 00000000..c046edb2
--- /dev/null
+++ b/service/updatestate.go
@@ -0,0 +1,56 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package service
+
+import (
+ "golang.zx2c4.com/wireguard/windows/updater"
+ "golang.zx2c4.com/wireguard/windows/version"
+ "log"
+ "time"
+)
+
+type UpdateState uint32
+
+const (
+ UpdateStateUnknown UpdateState = iota
+ UpdateStateFoundUpdate
+ UpdateStateUpdatesDisabledUnofficialBuild
+)
+
+var updateState = UpdateStateUnknown
+
+func checkForUpdates() {
+ if !version.IsRunningOfficialVersion() {
+ log.Println("Build is not official, so updates are disabled")
+ updateState = UpdateStateUpdatesDisabledUnofficialBuild
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+
+ time.Sleep(time.Second * 10)
+
+ first := true
+ for {
+ update, err := updater.CheckForUpdate()
+ if err == nil && update != nil {
+ log.Println("An update is available")
+ updateState = UpdateStateFoundUpdate
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+ if err != nil {
+ log.Printf("Update checker: %v", err)
+ if first {
+ time.Sleep(time.Minute * 4)
+ first = false
+ } else {
+ time.Sleep(time.Minute * 25)
+ }
+ } else {
+ time.Sleep(time.Hour)
+ }
+ }
+}