From 1d4c21dec9596e0fcacad2b84859261b312ac00e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 4 Mar 2019 03:04:22 +0100 Subject: ifaceconfig: deduplicate routes --- service/ifaceconfig.go | 49 +++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) (limited to 'service') diff --git a/service/ifaceconfig.go b/service/ifaceconfig.go index eda8fdb0..c6a8257d 100644 --- a/service/ifaceconfig.go +++ b/service/ifaceconfig.go @@ -6,6 +6,7 @@ package service import ( + "bytes" "encoding/binary" "errors" "golang.org/x/sys/windows" @@ -13,7 +14,7 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/windows/conf" "net" - "os" + "sort" "unsafe" ) @@ -157,25 +158,37 @@ func configureInterface(conf *conf.Config, guid *windows.GUID) error { return err } - err = iface.FlushRoutes() - if err != nil { - return nil - } - for _, route := range routes { - err = iface.AddRoute(&route, false) - - //TODO: Ignoring duplicate errors like this maybe isn't very reasonable. - // instead we should make sure we're not adding duplicates ourselves when - // inserting the gateway routes. - if syserr, ok := err.(*os.SyscallError); ok { - if syserr.Err == windows.Errno(ERROR_OBJECT_ALREADY_EXISTS) { - err = nil - } + deduplicatedRoutes := make([]*winipcfg.RouteData, routeCount) + routeCount = 0 + sort.Slice(routes, func(i, j int) bool { + if routes[i].Metric < routes[j].Metric { + return true } - - if err != nil { - return err + if bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 { + return true + } + if bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 { + return true } + if bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 { + return true + } + return false + }) + for i := 0; i < len(routes); i++ { + if i > 0 && routes[i].Metric == routes[i - 1].Metric && + bytes.Equal(routes[i].NextHop, routes[i - 1].NextHop) && + bytes.Equal(routes[i].Destination.IP, routes[i - 1].Destination.IP) && + bytes.Equal(routes[i].Destination.Mask, routes[i - 1].Destination.Mask) { + continue + } + deduplicatedRoutes[routeCount] = &routes[i] + routeCount++ + } + + err = iface.SetRoutes(deduplicatedRoutes, false) + if err != nil { + return nil } err = iface.SetDNS(conf.Interface.Dns) -- cgit v1.2.3-59-g8ed1b