aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-07-12 15:53:10 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-08-02 19:10:58 +0200
commit5409c45a10dc7a045197bc4105c6a7bd5d29283f (patch)
treee64f9e7e09a4f3d965659413487781f452800256 /tunnel
parentversion: bump (diff)
downloadwireguard-windows-5409c45a10dc7a045197bc4105c6a7bd5d29283f.tar.xz
wireguard-windows-5409c45a10dc7a045197bc4105c6a7bd5d29283f.zip
driver: introduce new module for talking with kernel driver
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tunnel')
-rw-r--r--tunnel/addressconfig.go14
-rw-r--r--tunnel/defaultroutemonitor.go15
-rw-r--r--tunnel/interfacewatcher.go45
-rw-r--r--tunnel/mtumonitor.go113
-rw-r--r--tunnel/service.go175
-rw-r--r--tunnel/winipcfg/types.go31
6 files changed, 305 insertions, 88 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index 0dec95d0..fba7d770 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -12,8 +12,6 @@ import (
"sort"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -57,9 +55,7 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
}
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
- luid := winipcfg.LUID(tun.LUID())
-
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID, clamper mtuClamper) error {
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
@@ -151,7 +147,9 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t
}
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- tun.ForceMTU(int(ipif.NLMTU))
+ if clamper != nil {
+ clamper.ForceMTU(int(ipif.NLMTU))
+ }
}
if family == windows.AF_INET {
if foundDefault4 {
@@ -174,7 +172,7 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t
return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
}
-func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
+func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
doNotRestrict := true
if len(conf.Peers) == 1 && !conf.Interface.TableOff {
nextallowedip:
@@ -191,5 +189,5 @@ func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
}
}
log.Println("Enabling firewall rules")
- return firewall.EnableFirewall(tun.LUID(), doNotRestrict, conf.Interface.DNS)
+ return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
}
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
index aa0db675..ac4241c9 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/defaultroutemonitor.go
@@ -11,9 +11,7 @@ import (
"time"
"golang.org/x/sys/windows"
-
"golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
@@ -61,14 +59,17 @@ func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInte
return nil
}
-func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) {
+type mtuClamper interface {
+ ForceMTU(mtu int)
+}
+
+func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, clamper mtuClamper, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
var minMTU uint32
if family == windows.AF_INET {
minMTU = 576
} else if family == windows.AF_INET6 {
minMTU = 1280
}
- ourLUID := winipcfg.LUID(tun.LUID())
lastLUID := winipcfg.LUID(0)
lastIndex := ^uint32(0)
lastMTU := uint32(0)
@@ -103,7 +104,11 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT
if err != nil {
return err
}
- tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad.
+
+ // Having one MTU for both v4 and v6 kind of breaks the Windows model, so right now this just gets the
+ // second one which looks bad. However, internally, it doesn't seem like the Windows stack differentiates
+ // anyway, so it's probably fine.
+ clamper.ForceMTU(int(iface.NLMTU))
lastMTU = mtu
}
return nil
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index e12e5929..32132e93 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -12,8 +12,6 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
@@ -31,9 +29,10 @@ type interfaceWatcherEvent struct {
type interfaceWatcher struct {
errors chan interfaceWatcherError
- binder conn.BindSocketToInterface
- conf *conf.Config
- tun *tun.NativeTun
+ binder conn.BindSocketToInterface
+ clamper mtuClamper
+ conf *conf.Config
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
@@ -100,15 +99,24 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.binder != nil && iw.clamper != nil {
+ log.Printf("Monitoring default %s routes", ipversion)
+ *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
+ return
+ }
+ } else if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid, iw.clamper)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
@@ -127,11 +135,11 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
@@ -142,13 +150,13 @@ func watchInterface() (*interfaceWatcher, error) {
return iw, nil
}
-func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
- iw.binder, iw.conf, iw.tun = binder, conf, tun
+ iw.binder, iw.clamper, iw.conf, iw.luid = binder, clamper, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -160,7 +168,7 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,10 +194,9 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
luid.FlushDNS(windows.AF_INET)
diff --git a/tunnel/mtumonitor.go b/tunnel/mtumonitor.go
new file mode 100644
index 00000000..766ca1b8
--- /dev/null
+++ b/tunnel/mtumonitor.go
@@ -0,0 +1,113 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error {
+ r, err := winipcfg.GetIPForwardTable2(family)
+ if err != nil {
+ return err
+ }
+ lowestMetric := ^uint32(0)
+ index := uint32(0)
+ luid := winipcfg.LUID(0)
+ for i := range r {
+ if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
+ continue
+ }
+ ifrow, err := r[i].InterfaceLUID.Interface()
+ if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
+ continue
+ }
+
+ iface, err := r[i].InterfaceLUID.IPInterface(family)
+ if err != nil {
+ continue
+ }
+
+ if r[i].Metric+iface.Metric < lowestMetric {
+ lowestMetric = r[i].Metric + iface.Metric
+ index = r[i].InterfaceIndex
+ luid = r[i].InterfaceLUID
+ }
+ }
+ if luid == *lastLUID && index == *lastIndex {
+ return nil
+ }
+ *lastLUID = luid
+ *lastIndex = index
+ return nil
+}
+
+func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
+ var minMTU uint32
+ if family == windows.AF_INET {
+ minMTU = 576
+ } else if family == windows.AF_INET6 {
+ minMTU = 1280
+ }
+ lastLUID := winipcfg.LUID(0)
+ lastIndex := ^uint32(0)
+ lastMTU := uint32(0)
+ doIt := func() error {
+ err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex)
+ if err != nil {
+ return err
+ }
+ mtu := uint32(0)
+ if lastLUID != 0 {
+ iface, err := lastLUID.Interface()
+ if err != nil {
+ return err
+ }
+ if iface.MTU > 0 {
+ mtu = iface.MTU
+ }
+ }
+ if mtu > 0 && lastMTU != mtu {
+ iface, err := ourLUID.IPInterface(family)
+ if err != nil {
+ return err
+ }
+ iface.NLMTU = mtu - 80
+ if iface.NLMTU < minMTU {
+ iface.NLMTU = minMTU
+ }
+ err = iface.Set()
+ if err != nil {
+ return err
+ }
+ lastMTU = mtu
+ }
+ return nil
+ }
+ err := doIt()
+ if err != nil {
+ return nil, err
+ }
+ cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
+ if route != nil && route.DestinationPrefix.PrefixLength == 0 {
+ doIt()
+ }
+ })
+ if err != nil {
+ return nil, err
+ }
+ cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ if notificationType == winipcfg.MibParameterNotification {
+ doIt()
+ }
+ })
+ if err != nil {
+ cbr.Unregister()
+ return nil, err
+ }
+ return []winipcfg.ChangeCallback{cbr, cbi}, nil
+}
diff --git a/tunnel/service.go b/tunnel/service.go
index 63cd243f..a595994c 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -21,11 +21,12 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.zx2c4.com/wireguard/windows/version"
)
@@ -40,6 +41,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
var uapi net.Listener
var watcher *interfaceWatcher
var nativeTun *tun.NativeTun
+ var wintun tun.Device
+ var adapter *driver.Adapter
+ var luid winipcfg.LUID
var config *conf.Config
var err error
serviceError := services.ErrorSuccess
@@ -127,10 +131,10 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
if m, err := mgr.Connect(); err == nil {
if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
- /* If we don't do this, then the Wintun installation will block forever, because
- * installing a Wintun device starts a service too. Apparently at boot time, Windows
- * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce
- * that we're running before starting additional services.
+ /* If we don't do this, then the driver installation will block forever, because
+ * installing a network adapter starts the driver service too. Apparently at boot time,
+ * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't
+ * announce that we're running before starting additional services.
*/
log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
changes <- svc.Status{State: svc.Running}
@@ -146,34 +150,81 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
log.Println("Resolving DNS names")
- uapiConf, err := config.ToUAPI()
+ err = config.ResolveEndpoints()
if err != nil {
serviceError = services.ErrorDNSLookup
return
}
- log.Println("Creating Wintun interface")
- var wintun tun.Device
- for i := 0; i < 5; i++ {
- if i > 0 {
- time.Sleep(time.Second)
- log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ log.Println("Creating network adapter")
+ if conf.AdminBool("ExperimentalKernelDriver") {
+ // Does an adapter with this name already exist?
+ adapter, err = driver.DefaultPool.OpenAdapter(config.Name)
+ if err == nil {
+ // If so, we delete it, in case it has weird residual configuration.
+ _, err = adapter.Delete()
+ if err != nil {
+ err = fmt.Errorf("Error deleting already existing adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
}
- wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
- if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
- break
+ for i := 0; i < 5; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ }
+ var rebootRequired bool
+ adapter, rebootRequired, err = driver.DefaultPool.CreateAdapter(config.Name, deterministicGUID(config))
+ if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
+ if rebootRequired {
+ log.Println("Windows indicated a reboot is required.")
+ }
+ break
+ }
+ }
+ if err != nil {
+ err = fmt.Errorf("Error creating adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
+ defer adapter.Delete()
+ luid = adapter.LUID()
+ driverVersion, err := driver.RunningVersion()
+ if err != nil {
+ log.Printf("Warning: unable to determine driver version: %v", err)
+ } else {
+ log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
+ err = adapter.SetLogging(driver.AdapterLogOn)
+ if err != nil {
+ err = fmt.Errorf("Error enabling adapter logging: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
}
- }
- if err != nil {
- serviceError = services.ErrorCreateWintun
- return
- }
- nativeTun = wintun.(*tun.NativeTun)
- wintunVersion, err := nativeTun.RunningVersion()
- if err != nil {
- log.Printf("Warning: unable to determine Wintun version: %v", err)
} else {
- log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff)
+ for i := 0; i < 5; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ }
+ wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
+ if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
+ break
+ }
+ }
+ if err != nil {
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
+ nativeTun = wintun.(*tun.NativeTun)
+ luid = winipcfg.LUID(nativeTun.LUID())
+ driverVersion, err := nativeTun.RunningVersion()
+ if err != nil {
+ log.Printf("Warning: unable to determine driver version: %v", err)
+ } else {
+ log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
}
err = runScriptCommand(config.Interface.PreUp, config.Name)
@@ -182,7 +233,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- err = enableFirewall(config, nativeTun)
+ err = enableFirewall(config, luid)
if err != nil {
serviceError = services.ErrorFirewall
return
@@ -195,37 +246,51 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- log.Println("Creating interface instance")
- bind := conn.NewDefaultBind()
- dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
+ if nativeTun != nil {
+ log.Println("Creating interface instance")
+ bind := conn.NewDefaultBind()
+ dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
- log.Println("Setting interface configuration")
- uapi, err = ipc.UAPIListen(config.Name)
- if err != nil {
- serviceError = services.ErrorUAPIListen
- return
- }
- err = dev.IpcSet(uapiConf)
- if err != nil {
- serviceError = services.ErrorDeviceSetConfig
- return
- }
+ log.Println("Setting interface configuration")
+ uapi, err = ipc.UAPIListen(config.Name)
+ if err != nil {
+ serviceError = services.ErrorUAPIListen
+ return
+ }
+ err = dev.IpcSet(config.ToUAPI())
+ if err != nil {
+ serviceError = services.ErrorDeviceSetConfig
+ return
+ }
- log.Println("Bringing peers up")
- dev.Up()
+ log.Println("Bringing peers up")
+ dev.Up()
- watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun)
+ var clamper mtuClamper
+ clamper = nativeTun
+ watcher.Configure(bind.(conn.BindSocketToInterface), clamper, config, luid)
- log.Println("Listening for UAPI requests")
- go func() {
- for {
- conn, err := uapi.Accept()
- if err != nil {
- continue
+ log.Println("Listening for UAPI requests")
+ go func() {
+ for {
+ conn, err := uapi.Accept()
+ if err != nil {
+ continue
+ }
+ go dev.IpcHandle(conn)
}
- go dev.IpcHandle(conn)
+ }()
+ } else {
+ err = adapter.SetConfiguration(config.ToDriverConfiguration())
+ if err != nil {
+ serviceError = services.ErrorDeviceSetConfig
}
- }()
+ err = adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ serviceError = services.ErrorDeviceBringUp
+ }
+ watcher.Configure(nil, nil, config, luid)
+ }
err = runScriptCommand(config.Interface.PostUp, config.Name)
if err != nil {
@@ -236,6 +301,12 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
log.Println("Startup complete")
+ var devWaitChan chan struct{}
+ if dev != nil {
+ devWaitChan = dev.Wait()
+ } else {
+ devWaitChan = make(chan struct{})
+ }
for {
select {
case c := <-r:
@@ -247,7 +318,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
default:
log.Printf("Unexpected service control request #%d\n", c)
}
- case <-dev.Wait():
+ case <-devWaitChan:
return
case e := <-watcher.errors:
serviceError, err = e.serviceError, e.err
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 4dc52d8b..b06f05dd 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -6,6 +6,7 @@
package winipcfg
import (
+ "encoding/binary"
"net"
"unsafe"
@@ -734,6 +735,16 @@ type RawSockaddrInet struct {
data [26]byte
}
+func ntohs(i uint16) uint16 {
+ return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:])
+}
+
+func htons(i uint16) uint16 {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, i)
+ return *(*uint16)(unsafe.Pointer(&b[0]))
+}
+
// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
@@ -741,7 +752,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
copy(addr4.Addr[:], v4)
- addr4.Port = windows.Ntohs(port)
+ addr4.Port = htons(port)
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
@@ -751,7 +762,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
if v6 := ip.To16(); v6 != nil {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
- addr6.Port = windows.Ntohs(port)
+ addr6.Port = htons(port)
addr6.Flowinfo = 0
copy(addr6.Addr[:], v6)
addr6.Scope_id = 0
@@ -761,8 +772,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
return windows.ERROR_INVALID_PARAMETER
}
-// IP method returns IPv4 or IPv6 address.
-// If the address is neither IPv4 not IPv6 nil is returned.
+// IP returns IPv4 or IPv6 address, or nil if the address is neither.
func (addr *RawSockaddrInet) IP() net.IP {
switch addr.Family {
case windows.AF_INET:
@@ -775,6 +785,19 @@ func (addr *RawSockaddrInet) IP() net.IP {
return nil
}
+// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
+func (addr *RawSockaddrInet) Port() uint16 {
+ switch addr.Family {
+ case windows.AF_INET:
+ return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
+
+ case windows.AF_INET6:
+ return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
+ }
+
+ return 0
+}
+
// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeunicastipaddressentry
func (row *MibUnicastIPAddressRow) Init() {