aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-11-02 17:32:47 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-11-06 14:36:57 +0100
commit25d879e1ae8fb8c7aebd18829ab2ef9fbc8ac9fa (patch)
tree55a3d09f403adaa57c69d971599ac93a9e9c89ae
parentmanager: cleanup legacy wintun (diff)
downloadwireguard-windows-25d879e1ae8fb8c7aebd18829ab2ef9fbc8ac9fa.tar.xz
wireguard-windows-25d879e1ae8fb8c7aebd18829ab2ef9fbc8ac9fa.zip
global: switch to netip
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--conf/config.go62
-rw-r--r--conf/dnsresolver_windows.go22
-rw-r--r--conf/parser.go73
-rw-r--r--conf/parser_test.go8
-rw-r--r--conf/writer.go29
-rw-r--r--go.mod7
-rw-r--r--go.mod.master1
-rw-r--r--go.sum8
-rw-r--r--tunnel/addressconfig.go89
-rw-r--r--tunnel/deterministicguid.go8
-rw-r--r--tunnel/firewall/blocker.go5
-rw-r--r--tunnel/firewall/rules.go15
-rw-r--r--tunnel/winipcfg/luid.go71
-rw-r--r--tunnel/winipcfg/netsh.go13
-rw-r--r--tunnel/winipcfg/types.go75
-rw-r--r--tunnel/winipcfg/winipcfg_test.go83
-rw-r--r--ui/editdialog.go58
17 files changed, 260 insertions, 367 deletions
diff --git a/conf/config.go b/conf/config.go
index b600456d..e6cc6aac 100644
--- a/conf/config.go
+++ b/conf/config.go
@@ -10,10 +10,11 @@ import (
"crypto/subtle"
"encoding/base64"
"fmt"
- "net"
"strings"
"time"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/windows/l18n"
@@ -21,11 +22,6 @@ import (
const KeyLength = 32
-type IPCidr struct {
- IP net.IP
- Cidr uint8
-}
-
type Endpoint struct {
Host string
Port uint16
@@ -43,10 +39,10 @@ type Config struct {
type Interface struct {
PrivateKey Key
- Addresses []IPCidr
+ Addresses []netip.Prefix
ListenPort uint16
MTU uint16
- DNS []net.IP
+ DNS []netip.Addr
DNSSearch []string
PreUp string
PostUp string
@@ -58,7 +54,7 @@ type Interface struct {
type Peer struct {
PublicKey Key
PresharedKey Key
- AllowedIPs []IPCidr
+ AllowedIPs []netip.Prefix
Endpoint Endpoint
PersistentKeepalive uint16
@@ -67,62 +63,28 @@ type Peer struct {
LastHandshakeTime HandshakeTime
}
-func (r *IPCidr) String() string {
- return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
-}
-
-func (r *IPCidr) Bits() uint8 {
- if r.IP.To4() != nil {
- return 32
- }
- return 128
-}
-
-func (r *IPCidr) IPNet() net.IPNet {
- return net.IPNet{
- IP: r.IP,
- Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
- }
-}
-
-func (r *IPCidr) MaskSelf() {
- bits := int(r.Bits())
- mask := net.CIDRMask(int(r.Cidr), bits)
- for i := 0; i < bits/8; i++ {
- r.IP[i] &= mask[i]
- }
-}
-
func (conf *Config) IntersectsWith(other *Config) bool {
- type hashableIPCidr struct {
- ip string
- cidr byte
- }
- allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
+ allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
for _, a := range conf.Interface.Addresses {
- allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true
- a.MaskSelf()
- allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
+ allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true
+ allRoutes[a.Masked()] = true
}
for i := range conf.Peers {
for _, a := range conf.Peers[i].AllowedIPs {
- a.MaskSelf()
- allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
+ allRoutes[a.Masked()] = true
}
}
for _, a := range other.Interface.Addresses {
- if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] {
+ if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] {
return true
}
- a.MaskSelf()
- if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
+ if allRoutes[a.Masked()] {
return true
}
}
for i := range other.Peers {
for _, a := range other.Peers[i].AllowedIPs {
- a.MaskSelf()
- if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
+ if allRoutes[a.Masked()] {
return true
}
}
diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go
index 094b1029..3f7f7df8 100644
--- a/conf/dnsresolver_windows.go
+++ b/conf/dnsresolver_windows.go
@@ -6,13 +6,13 @@
package conf
import (
- "fmt"
"log"
- "net"
- "syscall"
+ "strconv"
"time"
"unsafe"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/services"
)
@@ -66,24 +66,24 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
return
}
defer windows.FreeAddrInfoW(result)
- ipv6 := ""
+ var v6 netip.Addr
for ; result != nil; result = result.Next {
switch result.Family {
case windows.AF_INET:
- return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
+ return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr).String(), nil
case windows.AF_INET6:
- if len(ipv6) != 0 {
+ if v6.IsValid() {
continue
}
- a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
- ipv6 = (net.IP)(a.Addr[:]).String()
+ a := (*windows.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
+ v6 = netip.AddrFrom16(a.Addr)
if a.Scope_id != 0 {
- ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
+ v6 = v6.WithZone(strconv.FormatUint(uint64(a.Scope_id), 10))
}
}
}
- if len(ipv6) != 0 {
- return ipv6, nil
+ if v6.IsValid() {
+ return v6.String(), nil
}
err = windows.WSAHOST_NOT_FOUND
return
diff --git a/conf/parser.go b/conf/parser.go
index 83f25964..477a5205 100644
--- a/conf/parser.go
+++ b/conf/parser.go
@@ -7,10 +7,11 @@ package conf
import (
"encoding/base64"
- "net"
"strconv"
"strings"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
"golang.org/x/text/encoding/unicode"
@@ -27,43 +28,16 @@ func (e *ParseError) Error() string {
return l18n.Sprintf("%s: %q", e.why, e.offender)
}
-func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
- var addrStr, cidrStr string
- var cidr int
-
- i := strings.IndexByte(s, '/')
- if i < 0 {
- addrStr = s
- } else {
- addrStr, cidrStr = s[:i], s[i+1:]
- }
-
- err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
- addr := net.ParseIP(addrStr)
- if addr == nil {
- return
- }
- maybeV4 := addr.To4()
- if maybeV4 != nil {
- addr = maybeV4
+func parseIPCidr(s string) (netip.Prefix, error) {
+ ipcidr, err := netip.ParsePrefix(s)
+ if err == nil {
+ return ipcidr, nil
}
- if len(cidrStr) > 0 {
- err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
- cidr, err = strconv.Atoi(cidrStr)
- if err != nil || cidr < 0 || cidr > 128 {
- return
- }
- if cidr > 32 && maybeV4 != nil {
- return
- }
- } else {
- if maybeV4 != nil {
- cidr = 32
- } else {
- cidr = 128
- }
+ addr, err := netip.ParseAddr(s)
+ if err != nil {
+ return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
}
- return &IPCidr{addr, uint8(cidr)}, nil
+ return netip.PrefixFrom(addr, addr.BitLen()), nil
}
func parseEndpoint(s string) (*Endpoint, error) {
@@ -87,8 +61,8 @@ func parseEndpoint(s string) (*Endpoint, error) {
if i := strings.LastIndexByte(host, '%'); i > 1 {
end = i
}
- maybeV6 := net.ParseIP(host[1:end])
- if maybeV6 == nil || len(maybeV6) != net.IPv6len {
+ maybeV6, err2 := netip.ParseAddr(host[1:end])
+ if err2 != nil || !maybeV6.Is6() {
return nil, err
}
} else {
@@ -96,7 +70,7 @@ func parseEndpoint(s string) (*Endpoint, error) {
}
host = host[1 : len(host)-1]
}
- return &Endpoint{host, uint16(port)}, nil
+ return &Endpoint{host, port}, nil
}
func parseMTU(s string) (uint16, error) {
@@ -256,7 +230,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
- conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
+ conf.Interface.Addresses = append(conf.Interface.Addresses, a)
}
case "dns":
addresses, err := splitList(val)
@@ -264,8 +238,8 @@ func FromWgQuick(s string, name string) (*Config, error) {
return nil, err
}
for _, address := range addresses {
- a := net.ParseIP(address)
- if a == nil {
+ a, err := netip.ParseAddr(address)
+ if err != nil {
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
} else {
conf.Interface.DNS = append(conf.Interface.DNS, a)
@@ -312,7 +286,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
- peer.AllowedIPs = append(peer.AllowedIPs, *a)
+ peer.AllowedIPs = append(peer.AllowedIPs, a)
}
case "persistentkeepalive":
p, err := parsePersistentKeepalive(val)
@@ -399,7 +373,7 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
}
if p.Flags&driver.PeerHasEndpoint != 0 {
peer.Endpoint.Port = p.Endpoint.Port()
- peer.Endpoint.Host = p.Endpoint.IP().String()
+ peer.Endpoint.Host = p.Endpoint.Addr().String()
}
if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
peer.PersistentKeepalive = p.PersistentKeepalive
@@ -416,16 +390,13 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
} else {
a = a.NextAllowedIP()
}
- var ip net.IP
+ var ip netip.Addr
if a.AddressFamily == windows.AF_INET {
- ip = a.Address[:4]
+ ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
} else if a.AddressFamily == windows.AF_INET6 {
- ip = a.Address[:16]
+ ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
}
- peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{
- IP: ip,
- Cidr: a.Cidr,
- })
+ peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr)))
}
conf.Peers = append(conf.Peers, peer)
}
diff --git a/conf/parser_test.go b/conf/parser_test.go
index f80d6d18..c2f757c7 100644
--- a/conf/parser_test.go
+++ b/conf/parser_test.go
@@ -6,10 +6,11 @@
package conf
import (
- "net"
"reflect"
"runtime"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
const testInput = `
@@ -77,10 +78,9 @@ func contains(t *testing.T, list, element interface{}) bool {
func TestFromWgQuick(t *testing.T) {
conf, err := FromWgQuick(testInput, "test")
if noError(t, err) {
-
lenTest(t, conf.Interface.Addresses, 2)
- contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
- contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
+ contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16))
+ contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24))
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
equal(t, uint16(51820), conf.Interface.ListenPort)
diff --git a/conf/writer.go b/conf/writer.go
index 3e24559f..61abb672 100644
--- a/conf/writer.go
+++ b/conf/writer.go
@@ -7,10 +7,11 @@ package conf
import (
"fmt"
- "net"
"strings"
"unsafe"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -111,8 +112,11 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
}
var endpoint winipcfg.RawSockaddrInet
if !config.Peers[i].Endpoint.IsEmpty() {
- flags |= driver.PeerHasEndpoint
- endpoint.SetIP(net.ParseIP(config.Peers[i].Endpoint.Host), config.Peers[i].Endpoint.Port)
+ addr, err := netip.ParseAddr(config.Peers[i].Endpoint.Host)
+ if err == nil {
+ flags |= driver.PeerHasEndpoint
+ endpoint.SetAddrPort(netip.AddrPortFrom(addr, config.Peers[i].Endpoint.Port))
+ }
}
c.AppendPeer(&driver.Peer{
Flags: flags,
@@ -123,20 +127,13 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)),
})
for j := range config.Peers[i].AllowedIPs {
- var family winipcfg.AddressFamily
- var ip net.IP
- if ip = config.Peers[i].AllowedIPs[j].IP.To4(); ip != nil {
- family = windows.AF_INET
- } else if ip = config.Peers[i].AllowedIPs[j].IP.To16(); ip != nil {
- family = windows.AF_INET6
- } else {
- ip = config.Peers[i].AllowedIPs[j].IP
- }
- a := &driver.AllowedIP{
- AddressFamily: family,
- Cidr: config.Peers[i].AllowedIPs[j].Cidr,
+ a := &driver.AllowedIP{Cidr: uint8(config.Peers[i].AllowedIPs[j].Bits())}
+ copy(a.Address[:], config.Peers[i].AllowedIPs[j].Addr().AsSlice())
+ if config.Peers[i].AllowedIPs[j].Addr().Is4() {
+ a.AddressFamily = windows.AF_INET
+ } else if config.Peers[i].AllowedIPs[j].Addr().Is6() {
+ a.AddressFamily = windows.AF_INET6
}
- copy(a.Address[:], ip)
c.AppendAllowedIP(a)
}
}
diff --git a/go.mod b/go.mod
index 55779a43..084d8578 100644
--- a/go.mod
+++ b/go.mod
@@ -6,9 +6,10 @@ require (
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794
github.com/lxn/win v0.0.0-20210218163916-a377121e959e
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
- golang.org/x/net v0.0.0-20211029160332-540bb53d3b2e
- golang.org/x/sys v0.0.0-20211029165221-6e7872819dc8
- golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956
+ golang.org/x/net v0.0.0-20211105192438-b53810dc28af
+ golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68
+ golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2
+ golang.zx2c4.com/go118/netip v0.0.0-20211106132939-9d41d90554dd
)
require (
diff --git a/go.mod.master b/go.mod.master
index b106dfea..897e97fa 100644
--- a/go.mod.master
+++ b/go.mod.master
@@ -9,6 +9,7 @@ require (
golang.org/x/net latest
golang.org/x/sys latest
golang.org/x/text master
+ golang.zx2c4.com/go118/netip master
)
replace (
diff --git a/go.sum b/go.sum
index 58d5ab14..f10dbdd8 100644
--- a/go.sum
+++ b/go.sum
@@ -6,11 +6,11 @@ golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20211029160332-540bb53d3b2e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20211105192438-b53810dc28af/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956 h1:xw/3G76i8BwoCoEZ8RzhVpFrHEz4Qm9D7zPckwa7KVM=
-golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0=
+golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 h1:GLw7MR8AfAG2GmGcmVgObFOHXYypgGjnGno25RDwn3Y=
+golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
@@ -18,6 +18,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.zx2c4.com/go118/netip v0.0.0-20211106132939-9d41d90554dd h1:gUHae7sCd+tFJLcCximWeBFD2b6Jg3O7UaNaPvjIJHc=
+golang.zx2c4.com/go118/netip v0.0.0-20211106132939-9d41d90554dd/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard/windows v0.0.0-20210121140954-e7fc19d483bd h1:kAUzMAITME2MCtrXBaUa9P4tndiXGWO674k9gn6ZR28=
golang.zx2c4.com/wireguard/windows v0.0.0-20210121140954-e7fc19d483bd/go.mod h1:Y+FYqVFaQO6a+1uigm0N0GiuaZrLEaBxEiJ8tfH9sMQ=
golang.zx2c4.com/wireguard/windows v0.0.0-20210224134948-620c54ef6199 h1:ogXKLng/Myrt2odYTkleySGzQj/GWg9GV1AQ8P9NnU4=
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index f315cd15..3068fdbb 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -6,13 +6,12 @@
package tunnel
import (
- "bytes"
"fmt"
"log"
- "net"
- "sort"
"time"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/services"
@@ -20,19 +19,13 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
+func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) {
if len(addresses) == 0 {
return
}
- addrToStr := func(ip *net.IP) string {
- if ip4 := ip.To4(); ip4 != nil {
- return string(ip4)
- }
- return string(*ip)
- }
- addrHash := make(map[string]bool, len(addresses))
+ addrHash := make(map[netip.Addr]bool, len(addresses))
for i := range addresses {
- addrHash[addrToStr(&addresses[i].IP)] = true
+ addrHash[addresses[i].Addr()] = true
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
@@ -43,11 +36,10 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
- ip := address.Address.IP()
- if addrHash[addrToStr(&ip)] {
- ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
- log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName())
- iface.LUID.DeleteIPAddress(ipnet)
+ if ip := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] {
+ prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
+ log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", prefix.String(), iface.FriendlyName())
+ iface.LUID.DeleteIPAddress(prefix)
}
}
}
@@ -69,14 +61,14 @@ startOver:
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
}
- routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
- addresses := make([]net.IPNet, len(conf.Interface.Addresses))
+ routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount)
+ addresses := make([]netip.Prefix, len(conf.Interface.Addresses))
var haveV4Address, haveV6Address bool
for i, addr := range conf.Interface.Addresses {
- addresses[i] = addr.IPNet()
- if addr.Bits() == 32 {
+ addresses[i] = addr
+ if addr.Addr().Is4() {
haveV4Address = true
- } else if addr.Bits() == 128 {
+ } else if addr.Addr().Is6() {
haveV6Address = true
}
}
@@ -85,53 +77,32 @@ startOver:
foundDefault6 := false
for _, peer := range conf.Peers {
for _, allowedip := range peer.AllowedIPs {
- allowedip.MaskSelf()
- if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
+ if (allowedip.Addr().Is4() && !haveV4Address) || (allowedip.Addr().Is6() && !haveV6Address) {
continue
}
route := winipcfg.RouteData{
- Destination: allowedip.IPNet(),
+ Destination: allowedip.Masked(),
Metric: 0,
}
- if allowedip.Bits() == 32 {
- if allowedip.Cidr == 0 {
+ if allowedip.Addr().Is4() {
+ if allowedip.Bits() == 0 {
foundDefault4 = true
}
- route.NextHop = net.IPv4zero
- } else if allowedip.Bits() == 128 {
- if allowedip.Cidr == 0 {
+ route.NextHop = netip.IPv4Unspecified()
+ } else if allowedip.Addr().Is6() {
+ if allowedip.Bits() == 0 {
foundDefault6 = true
}
- route.NextHop = net.IPv6zero
+ route.NextHop = netip.IPv6Unspecified()
}
- routes = append(routes, route)
+ routes[route] = true
}
}
deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
- sort.Slice(routes, func(i, j int) bool {
- if routes[i].Metric != routes[j].Metric {
- return routes[i].Metric < routes[j].Metric
- }
- if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 {
- return c < 0
- }
- if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 {
- return c < 0
- }
- if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 {
- return c < 0
- }
- return false
- })
- for i := 0; i < len(routes); i++ {
- if i > 0 && routes[i].Metric == routes[i-1].Metric &&
- bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
- bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
- bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
- continue
- }
- deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
+ for route := range routes {
+ r := route
+ deduplicatedRoutes = append(deduplicatedRoutes, &r)
}
if !conf.Interface.TableOff {
@@ -189,14 +160,8 @@ startOver:
func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
doNotRestrict := true
if len(conf.Peers) == 1 && !conf.Interface.TableOff {
- nextallowedip:
for _, allowedip := range conf.Peers[0].AllowedIPs {
- if allowedip.Cidr == 0 {
- for _, b := range allowedip.IP {
- if b != 0 {
- continue nextallowedip
- }
- }
+ if allowedip.Bits() == 0 && allowedip == allowedip.Masked() {
doNotRestrict = false
break
}
diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go
index 455deaeb..afdab11e 100644
--- a/tunnel/deterministicguid.go
+++ b/tunnel/deterministicguid.go
@@ -80,13 +80,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID {
b2Number(len(peer.AllowedIPs))
sortedAllowedIPs := peer.AllowedIPs
sort.Slice(sortedAllowedIPs, func(i, j int) bool {
- if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj {
+ if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj {
return bi < bj
}
- if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr {
- return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr
+ if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() {
+ return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits()
}
- return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0
+ return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0
})
for _, allowedip := range sortedAllowedIPs {
b2String(allowedip.String())
diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go
index 2cb2c7f2..4be62aa9 100644
--- a/tunnel/firewall/blocker.go
+++ b/tunnel/firewall/blocker.go
@@ -7,9 +7,10 @@ package firewall
import (
"errors"
- "net"
"unsafe"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
)
@@ -101,7 +102,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) {
return bo, nil
}
-func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []net.IP) error {
+func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error {
if wfpSession != 0 {
return errors.New("The firewall has already been enabled")
}
diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go
index c4488a31..201a73e3 100644
--- a/tunnel/firewall/rules.go
+++ b/tunnel/firewall/rules.go
@@ -8,10 +8,11 @@ package firewall
import (
"encoding/binary"
"errors"
- "net"
"runtime"
"unsafe"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
)
@@ -985,7 +986,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error {
}
// Block all DNS traffic except towards specified DNS servers.
-func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error {
+func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error {
if weightDeny >= weightAllow {
return errors.New("The allow weight must be greater than the deny weight")
}
@@ -1106,8 +1107,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV4 = append(allowConditionsV4, denyConditions...)
for _, ip := range except {
- ip4 := ip.To4()
- if ip4 == nil {
+ if !ip.Is4() {
continue
}
allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{
@@ -1115,7 +1115,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
- value: uintptr(binary.BigEndian.Uint32(ip4)),
+ value: uintptr(binary.BigEndian.Uint32(ip.AsSlice())),
},
})
}
@@ -1124,11 +1124,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV6 = append(allowConditionsV6, denyConditions...)
for _, ip := range except {
- if ip.To4() != nil {
+ if !ip.Is6() {
continue
}
- var address wtFwpByteArray16
- copy(address.byteArray16[:], ip)
+ address := wtFwpByteArray16{byteArray16: ip.As16()}
allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go
index ca388acc..20b5a38b 100644
--- a/tunnel/winipcfg/luid.go
+++ b/tunnel/winipcfg/luid.go
@@ -7,9 +7,10 @@ package winipcfg
import (
"errors"
- "net"
"strings"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
)
@@ -76,10 +77,10 @@ func LUIDFromIndex(index uint32) (LUID, error) {
// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry)
-func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
+func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) {
row := &MibUnicastIPAddressRow{InterfaceLUID: luid}
- err := row.Address.SetIP(ip, 0)
+ err := row.Address.SetAddr(addr)
if err != nil {
return nil, err
}
@@ -94,25 +95,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddress(address net.IPNet) error {
+func (luid LUID) AddIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
row.DadState = DadStatePreferred
row.ValidLifetime = 0xffffffff
row.PreferredLifetime = 0xffffffff
- err := row.Address.SetIP(address.IP, 0)
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Create()
}
// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error {
for i := range addresses {
err := luid.AddIPAddress(addresses[i])
if err != nil {
@@ -123,7 +123,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddresses method sets new unicast IP addresses to the interface.
-func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(windows.AF_UNSPEC)
if err != nil {
return err
@@ -132,16 +132,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
-func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
+func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(family)
if err != nil {
return err
}
for i := range addresses {
- asV4 := addresses[i].IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !addresses[i].Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddIPAddress(addresses[i])
@@ -154,17 +153,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
-func (luid LUID) DeleteIPAddress(address net.IPNet) error {
+func (luid LUID) DeleteIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
- err := row.Address.SetIP(address.IP, 0)
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
// Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry().
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Delete()
}
@@ -188,17 +186,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error {
// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2).
// NOTE: If the corresponding route isn't found, the method will return error.
-func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) {
+func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
row.ValidLifetime = 0xffffffff
row.PreferredLifetime = 0xffffffff
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return nil, err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return nil, err
}
@@ -212,15 +210,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2
// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature.
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2)
-func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error {
+func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -255,10 +253,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
return err
}
for _, rd := range routesData {
- asV4 := rd.Destination.IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !rd.Destination.Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
@@ -271,15 +268,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
-func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
+func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -312,17 +309,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error {
}
// DNS method returns all DNS server addresses associated with the adapter.
-func (luid LUID) DNS() ([]net.IP, error) {
+func (luid LUID) DNS() ([]netip.Addr, error) {
addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
return nil, err
}
- r := make([]net.IP, 0, len(addresses))
+ r := make([]netip.Addr, 0, len(addresses))
for _, addr := range addresses {
if addr.LUID == luid {
for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next {
if ip := dns.Address.IP(); ip != nil {
- r = append(r, ip)
+ if a := netip.AddrFromSlice(ip); a.IsValid() {
+ r = append(r, a)
+ }
} else {
return nil, windows.ERROR_INVALID_PARAMETER
}
@@ -333,17 +332,15 @@ func (luid LUID) DNS() ([]net.IP, error) {
}
// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family.
-func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error {
+func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error {
if family != windows.AF_INET && family != windows.AF_INET6 {
return windows.ERROR_PROTOCOL_UNREACHABLE
}
var filteredServers []string
for _, server := range servers {
- if v4 := server.To4(); v4 != nil && family == windows.AF_INET {
- filteredServers = append(filteredServers, v4.String())
- } else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
- filteredServers = append(filteredServers, v6.String())
+ if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) {
+ filteredServers = append(filteredServers, server.String())
}
}
servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ","))
diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go
index 1f3d12d0..17e0778c 100644
--- a/tunnel/winipcfg/netsh.go
+++ b/tunnel/winipcfg/netsh.go
@@ -10,12 +10,13 @@ import (
"errors"
"fmt"
"io"
- "net"
"os/exec"
"path/filepath"
"strings"
"syscall"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
@@ -57,7 +58,7 @@ const (
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
)
-func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error {
+func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
var templateFlush string
if family == windows.AF_INET {
templateFlush = netshCmdTemplateFlush4
@@ -72,10 +73,10 @@ func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) e
}
cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
for i := 0; i < len(dnses); i++ {
- if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
- } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
+ if dnses[i].Is4() && family == windows.AF_INET {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String()))
+ } else if dnses[i].Is6() && family == windows.AF_INET6 {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String()))
}
}
return runNetsh(cmds)
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 789ee501..599bf789 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -8,9 +8,10 @@ package winipcfg
import (
"encoding/binary"
"fmt"
- "net"
"unsafe"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
)
@@ -584,8 +585,8 @@ const (
// RouteData structure describes a route to add
type RouteData struct {
- Destination net.IPNet
- NextHop net.IP
+ Destination netip.Prefix
+ NextHop netip.Addr
Metric uint32
}
@@ -748,44 +749,50 @@ func htons(i uint16) uint16 {
return *(*uint16)(unsafe.Pointer(&b[0]))
}
-// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
+// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
-func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
- if v4 := ip.To4(); v4 != nil {
+func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error {
+ if addrPort.Addr().Is4() {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
- copy(addr4.Addr[:], v4)
- addr4.Port = htons(port)
+ addr4.Addr = addrPort.Addr().As4()
+ addr4.Port = htons(addrPort.Port())
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
return nil
- }
-
- if v6 := ip.To16(); v6 != nil {
+ } else if addrPort.Addr().Is6() {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
- addr6.Port = htons(port)
+ addr6.Addr = addrPort.Addr().As16()
+ addr6.Port = htons(addrPort.Port())
addr6.Flowinfo = 0
- copy(addr6.Addr[:], v6)
addr6.Scope_id = 0
return nil
}
-
return windows.ERROR_INVALID_PARAMETER
}
-// IP returns IPv4 or IPv6 address, or nil if the address is neither.
-func (addr *RawSockaddrInet) IP() net.IP {
+// SetAddr method sets family and address to the given IPv4 or IPv6 address.
+// All other members of the structure are set to zero.
+func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error {
+ return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0))
+}
+
+// AddrPort returns the IP address and port.
+func (addr *RawSockaddrInet) AddrPort() netip.AddrPort {
+ return netip.AddrPortFrom(addr.Addr(), addr.Port())
+}
+
+// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither.
+func (addr *RawSockaddrInet) Addr() netip.Addr {
switch addr.Family {
case windows.AF_INET:
- return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:]
-
+ return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr)
case windows.AF_INET6:
- return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:]
+ return netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr)
}
-
- return nil
+ return netip.Addr{}
}
// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
@@ -793,11 +800,9 @@ func (addr *RawSockaddrInet) Port() uint16 {
switch addr.Family {
case windows.AF_INET:
return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
-
case windows.AF_INET6:
return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
}
-
return 0
}
@@ -874,32 +879,30 @@ func (tab *mibAnycastIPAddressTable) free() {
// IPAddressPrefix structure stores an IP address prefix.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix
type IPAddressPrefix struct {
- Prefix RawSockaddrInet
+ RawPrefix RawSockaddrInet
PrefixLength uint8
_ [2]byte
}
-// SetIPNet method sets IP address prefix using net.IPNet.
-func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error {
- err := prefix.Prefix.SetIP(net.IP, 0)
+// SetPrefix method sets IP address prefix using netip.Prefix.
+func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error {
+ err := prefix.RawPrefix.SetAddr(netPrefix.Addr())
if err != nil {
return err
}
- ones, _ := net.Mask.Size()
- prefix.PrefixLength = uint8(ones)
+ prefix.PrefixLength = uint8(netPrefix.Bits())
return nil
}
-// IPNet method returns IP address prefix as net.IPNet.
-// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately.
-func (prefix *IPAddressPrefix) IPNet() net.IPNet {
- switch prefix.Prefix.Family {
+// Prefix returns IP address prefix as netip.Prefix.
+func (prefix *IPAddressPrefix) Prefix() netip.Prefix {
+ switch prefix.RawPrefix.Family {
case windows.AF_INET:
- return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)}
+ return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
case windows.AF_INET6:
- return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)}
+ return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
}
- return net.IPNet{}
+ return netip.Prefix{}
}
// MibIPforwardRow2 structure stores information about an IP route entry.
diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go
index 5d3bc276..d863b1a2 100644
--- a/tunnel/winipcfg/winipcfg_test.go
+++ b/tunnel/winipcfg/winipcfg_test.go
@@ -22,13 +22,13 @@ Some tests in this file require:
package winipcfg
import (
- "bytes"
- "net"
"strings"
"syscall"
"testing"
"time"
+ "golang.zx2c4.com/go118/netip"
+
"golang.org/x/sys/windows"
)
@@ -38,22 +38,13 @@ const (
// TODO: Add IPv6 tests.
var (
- unexistentIPAddresToAdd = net.IPNet{
- IP: net.IP{172, 16, 1, 114},
- Mask: net.IPMask{255, 255, 255, 0},
- }
- unexistentRouteIPv4ToAdd = RouteData{
- Destination: net.IPNet{
- IP: net.IP{172, 16, 200, 0},
- Mask: net.IPMask{255, 255, 255, 0},
- },
- NextHop: net.IP{172, 16, 1, 2},
- Metric: 0,
- }
- dnsesToSet = []net.IP{
- net.IPv4(8, 8, 8, 8),
- net.IPv4(8, 8, 4, 4),
+ nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24")
+ nonexistentRouteIPv4ToAdd = RouteData{
+ Destination: netip.MustParsePrefix("172.16.200.0/24"),
+ NextHop: netip.MustParseAddr("172.16.1.2"),
+ Metric: 0,
}
+ dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}
)
func runningElevated() bool {
@@ -380,9 +371,9 @@ func TestAddDeleteIPAddress(t *testing.T) {
return
}
- addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String())
return
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
@@ -410,7 +401,7 @@ func TestAddDeleteIPAddress(t *testing.T) {
for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next {
count--
}
- err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd})
+ err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd})
if err != nil {
t.Errorf("LUID.AddIPAddresses() returned an error: %w", err)
}
@@ -424,26 +415,26 @@ func TestAddDeleteIPAddress(t *testing.T) {
if count != 1 {
t.Errorf("After adding there are %d new interface(s).", count)
}
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err != nil {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
} else if addr == nil {
- t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String())
}
if !created {
t.Errorf("Notification handler has not been called on add.")
}
- err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd)
+ err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd)
if err != nil {
t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String())
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
}
@@ -460,14 +451,13 @@ func TestGetRoutes(t *testing.T) {
}
func TestAddDeleteRoute(t *testing.T) {
- findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) {
+ findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) {
var family AddressFamily
- switch {
- case dest.IP.To4() != nil:
+ if dest.Addr().Is4() {
family = windows.AF_INET
- case dest.IP.To16() != nil:
+ } else if dest.Addr().Is6() {
family = windows.AF_INET6
- default:
+ } else {
return nil, windows.ERROR_INVALID_PARAMETER
}
r, err := GetIPForwardTable2(family)
@@ -475,9 +465,8 @@ func TestAddDeleteRoute(t *testing.T) {
return nil, err
}
matches := make([]MibIPforwardRow2, 0, len(r))
- ones, _ := dest.Mask.Size()
for _, route := range r {
- if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) {
+ if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() {
matches = append(matches, route)
}
}
@@ -494,20 +483,20 @@ func TestAddDeleteRoute(t *testing.T) {
return
}
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
- t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?")
+ t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?")
return
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.Route() returned an error: %w", err)
return
}
- routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
- t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes))
+ t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes))
}
var created, deleted bool
@@ -524,42 +513,42 @@ func TestAddDeleteRoute(t *testing.T) {
} else {
defer cb.Unregister()
}
- err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric)
+ err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric)
if err != nil {
t.Errorf("LUID.AddRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == windows.ERROR_NOT_FOUND {
t.Error("LUID.Route() returned nil although the route is added successfully.")
} else if err != nil {
t.Errorf("LUID.Route() returned an error: %w", err)
- } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) {
+ } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop {
t.Error("LUID.Route() returned a wrong route!")
}
if !created {
t.Errorf("Route handler has not been called on add.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 1 {
t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1)
- } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) {
- t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String())
+ } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() {
+ t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String())
}
- err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err != nil {
t.Errorf("LUID.DeleteRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
t.Error("LUID.Route() returned a route although it is removed successfully.")
} else if err != windows.ERROR_NOT_FOUND {
@@ -569,7 +558,7 @@ func TestAddDeleteRoute(t *testing.T) {
t.Errorf("Route handler has not been called on delete.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
@@ -606,7 +595,7 @@ func TestFlushDNS(t *testing.T) {
t.Errorf("LUID.DNS() returned an error: %w", err)
}
for _, a := range dns {
- if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) {
+ if a.Is4() {
n++
}
}
@@ -651,7 +640,7 @@ func TestSetDNS(t *testing.T) {
t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes))
} else {
for i := range dnsesToSet {
- if !dnsesToSet[i].Equal(newDNSes[i]) {
+ if dnsesToSet[i] != newDNSes[i] {
t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String())
}
}
diff --git a/ui/editdialog.go b/ui/editdialog.go
index 3b1521ff..4a7383d1 100644
--- a/ui/editdialog.go
+++ b/ui/editdialog.go
@@ -8,6 +8,8 @@ package ui
import (
"strings"
+ "golang.zx2c4.com/go118/netip"
+
"github.com/lxn/walk"
"github.com/lxn/win"
"golang.org/x/sys/windows"
@@ -185,15 +187,17 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() {
return
}
var (
- v40 = [4]byte{}
- v60 = [16]byte{}
- v48 = [4]byte{0x80}
- v68 = [16]byte{0x80}
+ v400 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
+ v600000 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
+ v401 = netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 1)
+ v600001 = netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 1)
+ v41281 = netip.PrefixFrom(netip.AddrFrom4([4]byte{0x80}), 1)
+ v680001 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
)
block := dlg.blockUntunneledTrafficCB.Checked()
cfg, err := conf.FromWgQuick(dlg.syntaxEdit.Text(), "temporary")
- var newAllowedIPs []conf.IPCidr
+ var newAllowedIPs []netip.Prefix
if err != nil {
goto err
@@ -202,7 +206,7 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() {
goto err
}
- newAllowedIPs = make([]conf.IPCidr, 0, len(cfg.Peers[0].AllowedIPs))
+ newAllowedIPs = make([]netip.Prefix, 0, len(cfg.Peers[0].AllowedIPs))
if block {
var (
foundV401 bool
@@ -211,13 +215,13 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() {
foundV680001 bool
)
for _, allowedip := range cfg.Peers[0].AllowedIPs {
- if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
+ if allowedip == v600001 {
foundV600001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) {
+ } else if allowedip == v680001 {
foundV680001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
+ } else if allowedip == v401 {
foundV401 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) {
+ } else if allowedip == v41281 {
foundV41281 = true
} else {
newAllowedIPs = append(newAllowedIPs, allowedip)
@@ -227,44 +231,44 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() {
goto err
}
if foundV401 && foundV41281 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 0})
+ newAllowedIPs = append(newAllowedIPs, v400)
} else if foundV401 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1})
+ newAllowedIPs = append(newAllowedIPs, v401)
} else if foundV41281 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1})
+ newAllowedIPs = append(newAllowedIPs, v41281)
}
if foundV600001 && foundV680001 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 0})
+ newAllowedIPs = append(newAllowedIPs, v600000)
} else if foundV600001 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1})
+ newAllowedIPs = append(newAllowedIPs, v600001)
} else if foundV680001 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1})
+ newAllowedIPs = append(newAllowedIPs, v680001)
}
cfg.Peers[0].AllowedIPs = newAllowedIPs
} else {
var (
- foundV400 bool
- foundV600 bool
+ foundV400 bool
+ foundV600000 bool
)
for _, allowedip := range cfg.Peers[0].AllowedIPs {
- if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
+ if allowedip == v600000 {
+ foundV600000 = true
+ } else if allowedip == v400 {
foundV400 = true
} else {
newAllowedIPs = append(newAllowedIPs, allowedip)
}
}
- if !(foundV400 || foundV600) {
+ if !(foundV400 || foundV600000) {
goto err
}
if foundV400 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1})
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1})
+ newAllowedIPs = append(newAllowedIPs, v401)
+ newAllowedIPs = append(newAllowedIPs, v41281)
}
- if foundV600 {
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1})
- newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1})
+ if foundV600000 {
+ newAllowedIPs = append(newAllowedIPs, v600001)
+ newAllowedIPs = append(newAllowedIPs, v680001)
}
cfg.Peers[0].AllowedIPs = newAllowedIPs
}