diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-07-12 15:53:10 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-08-02 19:10:58 +0200 |
commit | 5409c45a10dc7a045197bc4105c6a7bd5d29283f (patch) | |
tree | e64f9e7e09a4f3d965659413487781f452800256 /tunnel | |
parent | version: bump (diff) | |
download | wireguard-windows-5409c45a10dc7a045197bc4105c6a7bd5d29283f.tar.xz wireguard-windows-5409c45a10dc7a045197bc4105c6a7bd5d29283f.zip |
driver: introduce new module for talking with kernel driver
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tunnel')
-rw-r--r-- | tunnel/addressconfig.go | 14 | ||||
-rw-r--r-- | tunnel/defaultroutemonitor.go | 15 | ||||
-rw-r--r-- | tunnel/interfacewatcher.go | 45 | ||||
-rw-r--r-- | tunnel/mtumonitor.go | 113 | ||||
-rw-r--r-- | tunnel/service.go | 175 | ||||
-rw-r--r-- | tunnel/winipcfg/types.go | 31 |
6 files changed, 305 insertions, 88 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 0dec95d0..fba7d770 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -12,8 +12,6 @@ import ( "sort" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -57,9 +55,7 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add } } -func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { - luid := winipcfg.LUID(tun.LUID()) - +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID, clamper mtuClamper) error { estimatedRouteCount := 0 for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) @@ -151,7 +147,9 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t } if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) - tun.ForceMTU(int(ipif.NLMTU)) + if clamper != nil { + clamper.ForceMTU(int(ipif.NLMTU)) + } } if family == windows.AF_INET { if foundDefault4 { @@ -174,7 +172,7 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) } -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { +func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error { doNotRestrict := true if len(conf.Peers) == 1 && !conf.Interface.TableOff { nextallowedip: @@ -191,5 +189,5 @@ func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { } } log.Println("Enabling firewall rules") - return firewall.EnableFirewall(tun.LUID(), doNotRestrict, conf.Interface.DNS) + return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS) } diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index aa0db675..ac4241c9 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -11,9 +11,7 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) @@ -61,14 +59,17 @@ func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInte return nil } -func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) { +type mtuClamper interface { + ForceMTU(mtu int) +} + +func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, clamper mtuClamper, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) { var minMTU uint32 if family == windows.AF_INET { minMTU = 576 } else if family == windows.AF_INET6 { minMTU = 1280 } - ourLUID := winipcfg.LUID(tun.LUID()) lastLUID := winipcfg.LUID(0) lastIndex := ^uint32(0) lastMTU := uint32(0) @@ -103,7 +104,11 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT if err != nil { return err } - tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad. + + // Having one MTU for both v4 and v6 kind of breaks the Windows model, so right now this just gets the + // second one which looks bad. However, internally, it doesn't seem like the Windows stack differentiates + // anyway, so it's probably fine. + clamper.ForceMTU(int(iface.NLMTU)) lastMTU = mtu } return nil diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index e12e5929..32132e93 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -12,8 +12,6 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" @@ -31,9 +29,10 @@ type interfaceWatcherEvent struct { type interfaceWatcher struct { errors chan interfaceWatcherError - binder conn.BindSocketToInterface - conf *conf.Config - tun *tun.NativeTun + binder conn.BindSocketToInterface + clamper mtuClamper + conf *conf.Config + luid winipcfg.LUID setupMutex sync.Mutex interfaceChangeCallback winipcfg.ChangeCallback @@ -100,15 +99,24 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return + if iw.binder != nil && iw.clamper != nil { + log.Printf("Monitoring default %s routes", ipversion) + *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} + return + } + } else if iw.conf.Interface.MTU == 0 { + log.Printf("Monitoring MTU of default %s routes", ipversion) + *changeCallbacks, err = monitorMTU(family, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err} + return + } } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.tun) + err = configureInterface(family, iw.conf, iw.luid, iw.clamper) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return @@ -127,11 +135,11 @@ func watchInterface() (*interfaceWatcher, error) { if notificationType != winipcfg.MibAddInstance { return } - if iw.tun == nil { + if iw.luid == 0 { iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) return } - if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + if iface.InterfaceLUID != iw.luid { return } iw.setup(iface.Family) @@ -142,13 +150,13 @@ func watchInterface() (*interfaceWatcher, error) { return iw, nil } -func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) { +func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() - iw.binder, iw.conf, iw.tun = binder, conf, tun + iw.binder, iw.clamper, iw.conf, iw.luid = binder, clamper, conf, luid for _, event := range iw.storedEvents { - if event.luid == winipcfg.LUID(iw.tun.LUID()) { + if event.luid == luid { iw.setup(event.family) } } @@ -160,7 +168,7 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks4 := iw.changeCallbacks4 changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback - tun := iw.tun + luid := iw.luid iw.setupMutex.Unlock() if interfaceChangeCallback != nil { @@ -186,10 +194,9 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() - if tun != nil && iw.tun == tun { + if luid != 0 && iw.luid == luid { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) luid.FlushDNS(windows.AF_INET) diff --git a/tunnel/mtumonitor.go b/tunnel/mtumonitor.go new file mode 100644 index 00000000..766ca1b8 --- /dev/null +++ b/tunnel/mtumonitor.go @@ -0,0 +1,113 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error { + r, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return err + } + lowestMetric := ^uint32(0) + index := uint32(0) + luid := winipcfg.LUID(0) + for i := range r { + if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID { + continue + } + ifrow, err := r[i].InterfaceLUID.Interface() + if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { + continue + } + + iface, err := r[i].InterfaceLUID.IPInterface(family) + if err != nil { + continue + } + + if r[i].Metric+iface.Metric < lowestMetric { + lowestMetric = r[i].Metric + iface.Metric + index = r[i].InterfaceIndex + luid = r[i].InterfaceLUID + } + } + if luid == *lastLUID && index == *lastIndex { + return nil + } + *lastLUID = luid + *lastIndex = index + return nil +} + +func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) { + var minMTU uint32 + if family == windows.AF_INET { + minMTU = 576 + } else if family == windows.AF_INET6 { + minMTU = 1280 + } + lastLUID := winipcfg.LUID(0) + lastIndex := ^uint32(0) + lastMTU := uint32(0) + doIt := func() error { + err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex) + if err != nil { + return err + } + mtu := uint32(0) + if lastLUID != 0 { + iface, err := lastLUID.Interface() + if err != nil { + return err + } + if iface.MTU > 0 { + mtu = iface.MTU + } + } + if mtu > 0 && lastMTU != mtu { + iface, err := ourLUID.IPInterface(family) + if err != nil { + return err + } + iface.NLMTU = mtu - 80 + if iface.NLMTU < minMTU { + iface.NLMTU = minMTU + } + err = iface.Set() + if err != nil { + return err + } + lastMTU = mtu + } + return nil + } + err := doIt() + if err != nil { + return nil, err + } + cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + if route != nil && route.DestinationPrefix.PrefixLength == 0 { + doIt() + } + }) + if err != nil { + return nil, err + } + cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + if notificationType == winipcfg.MibParameterNotification { + doIt() + } + }) + if err != nil { + cbr.Unregister() + return nil, err + } + return []winipcfg.ChangeCallback{cbr, cbi}, nil +} diff --git a/tunnel/service.go b/tunnel/service.go index 63cd243f..a595994c 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -21,11 +21,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/version" ) @@ -40,6 +41,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, var uapi net.Listener var watcher *interfaceWatcher var nativeTun *tun.NativeTun + var wintun tun.Device + var adapter *driver.Adapter + var luid winipcfg.LUID var config *conf.Config var err error serviceError := services.ErrorSuccess @@ -127,10 +131,10 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, if m, err := mgr.Connect(); err == nil { if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the Wintun installation will block forever, because - * installing a Wintun device starts a service too. Apparently at boot time, Windows - * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce - * that we're running before starting additional services. + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. */ log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) changes <- svc.Status{State: svc.Running} @@ -146,34 +150,81 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Resolving DNS names") - uapiConf, err := config.ToUAPI() + err = config.ResolveEndpoints() if err != nil { serviceError = services.ErrorDNSLookup return } - log.Println("Creating Wintun interface") - var wintun tun.Device - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + log.Println("Creating network adapter") + if conf.AdminBool("ExperimentalKernelDriver") { + // Does an adapter with this name already exist? + adapter, err = driver.DefaultPool.OpenAdapter(config.Name) + if err == nil { + // If so, we delete it, in case it has weird residual configuration. + _, err = adapter.Delete() + if err != nil { + err = fmt.Errorf("Error deleting already existing adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } } - wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) - if err == nil || windows.DurationSinceBoot() > time.Minute*4 { - break + for i := 0; i < 5; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + } + var rebootRequired bool + adapter, rebootRequired, err = driver.DefaultPool.CreateAdapter(config.Name, deterministicGUID(config)) + if err == nil || windows.DurationSinceBoot() > time.Minute*4 { + if rebootRequired { + log.Println("Windows indicated a reboot is required.") + } + break + } + } + if err != nil { + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } + defer adapter.Delete() + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) + } else { + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } - } - if err != nil { - serviceError = services.ErrorCreateWintun - return - } - nativeTun = wintun.(*tun.NativeTun) - wintunVersion, err := nativeTun.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine Wintun version: %v", err) } else { - log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff) + for i := 0; i < 5; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + } + wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) + if err == nil || windows.DurationSinceBoot() > time.Minute*4 { + break + } + } + if err != nil { + serviceError = services.ErrorCreateNetworkAdapter + return + } + nativeTun = wintun.(*tun.NativeTun) + luid = winipcfg.LUID(nativeTun.LUID()) + driverVersion, err := nativeTun.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) + } else { + log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } } err = runScriptCommand(config.Interface.PreUp, config.Name) @@ -182,7 +233,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - err = enableFirewall(config, nativeTun) + err = enableFirewall(config, luid) if err != nil { serviceError = services.ErrorFirewall return @@ -195,37 +246,51 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - log.Println("Creating interface instance") - bind := conn.NewDefaultBind() - dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) + if nativeTun != nil { + log.Println("Creating interface instance") + bind := conn.NewDefaultBind() + dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(config.Name) - if err != nil { - serviceError = services.ErrorUAPIListen - return - } - err = dev.IpcSet(uapiConf) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } + log.Println("Setting interface configuration") + uapi, err = ipc.UAPIListen(config.Name) + if err != nil { + serviceError = services.ErrorUAPIListen + return + } + err = dev.IpcSet(config.ToUAPI()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig + return + } - log.Println("Bringing peers up") - dev.Up() + log.Println("Bringing peers up") + dev.Up() - watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun) + var clamper mtuClamper + clamper = nativeTun + watcher.Configure(bind.(conn.BindSocketToInterface), clamper, config, luid) - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue + log.Println("Listening for UAPI requests") + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + continue + } + go dev.IpcHandle(conn) } - go dev.IpcHandle(conn) + }() + } else { + err = adapter.SetConfiguration(config.ToDriverConfiguration()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig } - }() + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + serviceError = services.ErrorDeviceBringUp + } + watcher.Configure(nil, nil, config, luid) + } err = runScriptCommand(config.Interface.PostUp, config.Name) if err != nil { @@ -236,6 +301,12 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} log.Println("Startup complete") + var devWaitChan chan struct{} + if dev != nil { + devWaitChan = dev.Wait() + } else { + devWaitChan = make(chan struct{}) + } for { select { case c := <-r: @@ -247,7 +318,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-dev.Wait(): + case <-devWaitChan: return case e := <-watcher.errors: serviceError, err = e.serviceError, e.err diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 4dc52d8b..b06f05dd 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -6,6 +6,7 @@ package winipcfg import ( + "encoding/binary" "net" "unsafe" @@ -734,6 +735,16 @@ type RawSockaddrInet struct { data [26]byte } +func ntohs(i uint16) uint16 { + return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:]) +} + +func htons(i uint16) uint16 { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, i) + return *(*uint16)(unsafe.Pointer(&b[0])) +} + // SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port. // All other members of the structure are set to zero. func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { @@ -741,7 +752,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) addr4.Family = windows.AF_INET copy(addr4.Addr[:], v4) - addr4.Port = windows.Ntohs(port) + addr4.Port = htons(port) for i := 0; i < 8; i++ { addr4.Zero[i] = 0 } @@ -751,7 +762,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { if v6 := ip.To16(); v6 != nil { addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) addr6.Family = windows.AF_INET6 - addr6.Port = windows.Ntohs(port) + addr6.Port = htons(port) addr6.Flowinfo = 0 copy(addr6.Addr[:], v6) addr6.Scope_id = 0 @@ -761,8 +772,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { return windows.ERROR_INVALID_PARAMETER } -// IP method returns IPv4 or IPv6 address. -// If the address is neither IPv4 not IPv6 nil is returned. +// IP returns IPv4 or IPv6 address, or nil if the address is neither. func (addr *RawSockaddrInet) IP() net.IP { switch addr.Family { case windows.AF_INET: @@ -775,6 +785,19 @@ func (addr *RawSockaddrInet) IP() net.IP { return nil } +// Port returns the port if the address if IPv4 or IPv6, or 0 if neither. +func (addr *RawSockaddrInet) Port() uint16 { + switch addr.Family { + case windows.AF_INET: + return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port) + + case windows.AF_INET6: + return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port) + } + + return 0 +} + // Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer. // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeunicastipaddressentry func (row *MibUnicastIPAddressRow) Init() { |