diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-08-13 00:25:13 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-08-13 00:25:13 +0200 |
commit | 9f92a337376794baaff9b7e60e715a0f102549ba (patch) | |
tree | a73c2c3f16a1785a6abdb7abb004c47f35d63a33 | |
parent | embeddable-dll-service: allow falling back to wireguard-go (diff) | |
download | wireguard-windows-9f92a337376794baaff9b7e60e715a0f102549ba.tar.xz wireguard-windows-9f92a337376794baaff9b7e60e715a0f102549ba.zip |
manager: make multiple tunnels mode automatic
Rather than having to set a registry knob to enable multiple tunnels, it
is now automatic. If an additional activated tunnel has the same route
subnets or interface IP addresses as a previous tunnel, that previous
one is stopped. But if there's no overlap, then they coexist.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | conf/config.go | 37 | ||||
-rw-r--r-- | docs/adminregistry.md | 14 | ||||
-rw-r--r-- | manager/ipc_server.go | 69 |
3 files changed, 74 insertions, 46 deletions
diff --git a/conf/config.go b/conf/config.go index a6266851..c35e6664 100644 --- a/conf/config.go +++ b/conf/config.go @@ -94,6 +94,43 @@ func (r *IPCidr) MaskSelf() { } } +func (conf *Config) IntersectsWith(other *Config) bool { + type hashableIPCidr struct { + ip string + cidr byte + } + allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3) + for _, a := range conf.Interface.Addresses { + allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true + a.MaskSelf() + allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true + } + for i := range conf.Peers { + for _, a := range conf.Peers[i].AllowedIPs { + a.MaskSelf() + allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true + } + } + for _, a := range other.Interface.Addresses { + if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] { + return true + } + a.MaskSelf() + if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] { + return true + } + } + for i := range other.Peers { + for _, a := range other.Peers[i].AllowedIPs { + a.MaskSelf() + if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] { + return true + } + } + } + return false +} + func (e *Endpoint) String() string { if strings.IndexByte(e.Host, ':') > 0 { return fmt.Sprintf("[%s]:%d", e.Host, e.Port) diff --git a/docs/adminregistry.md b/docs/adminregistry.md index 34033446..6d8838f6 100644 --- a/docs/adminregistry.md +++ b/docs/adminregistry.md @@ -38,20 +38,6 @@ executing these scripts. > reg add HKLM\Software\WireGuard /v DangerousScriptExecution /t REG_DWORD /d 1 /f ``` -#### `HKLM\Software\WireGuard\MultipleSimultaneousTunnels` - -When this key is set to `DWORD(1)`, the UI may start multiple tunnels at the -same time; otherwise, an existing tunnel is stopped when a new one is started. -Note that it is always possible, regardless of this key, to start multiple -tunnels using `wireguard /installtunnelservice`; this controls only the semantics -of tunnel start requests coming from the UI. If all goes well, this key will be -removed and the logic of whether to stop existing tunnels will be based on -overlapping routes, but for now, this key provides a manual override. - -``` -> reg add HKLM\Software\WireGuard /v MultipleSimultaneousTunnels /t REG_DWORD /d 1 /f -``` - #### `HKLM\Software\WireGuard\ExperimentalKernelDriver` When this key is set to `DWORD(1)`, an experimental kernel driver from the diff --git a/manager/ipc_server.go b/manager/ipc_server.go index fe195094..ba785ced 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -108,44 +108,49 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) } func (s *ManagerService) Start(tunnelName string) error { - // TODO: Rather than being lazy and gating this behind a knob (yuck!), we should instead keep track of the routes - // of each tunnel, and only deactivate in the case of a tunnel with identical routes being added. - if !conf.AdminBool("MultipleSimultaneousTunnels") { - trackedTunnelsLock.Lock() - tt := make([]string, 0, len(trackedTunnels)) - var inTransition string - for t, state := range trackedTunnels { - tt = append(tt, t) - if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { - inTransition = t - break - } + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + + // Figure out which tunnels have intersecting addresses/routes and stop those. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + c2, err := conf.LoadFromName(t) + if err != nil || !c.IntersectsWith(c2) { + // If we can't get the config, assume it doesn't intersect. + continue + } + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break } - trackedTunnelsLock.Unlock() - if len(inTransition) != 0 { - return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition) + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition) + } + + // Stop those intersecting tunnels asynchronously. + go func() { + for _, t := range tt { + s.Stop(t) } - go func() { - for _, t := range tt { + for _, t := range tt { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) s.Stop(t) + time.Sleep(time.Millisecond * 100) } - for _, t := range tt { - state, err := s.State(t) - if err == nil && (state == TunnelStarted || state == TunnelStarting) { - log.Printf("[%s] Trying again to stop zombie tunnel", t) - s.Stop(t) - time.Sleep(time.Millisecond * 100) - } - } - }() - } + } + }() time.AfterFunc(time.Second*10, cleanupStaleNetworkInterfaces) - // After that process is started -- it's somewhat asynchronous -- we install the new one. - c, err := conf.LoadFromName(tunnelName) - if err != nil { - return err - } + // After the stop process has begun, but before it's finished, we install the new one. path, err := c.Path() if err != nil { return err |