From 72bcd606919eee38379c21a4d870913bb75345db Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Tue, 27 Aug 2019 09:17:19 -0600 Subject: firewall: do not block DNS when no kill-switch --- tunnel/addressconfig.go | 31 +++++++++++++------------ tunnel/firewall/blocker.go | 56 +++++++++++++++++++++------------------------- tunnel/service.go | 12 +++++----- 3 files changed, 49 insertions(+), 50 deletions(-) diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index a1e5dc59..6032d452 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -178,24 +178,27 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t return nil } -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { - restrictAll := false - if len(conf.Peers) == 1 { - nextallowedip: - for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } +func shouldEnableFirewall(conf *conf.Config) bool { + if len(conf.Peers) != 1 { + return false + } +nextallowedip: + for _, allowedip := range conf.Peers[0].AllowedIPs { + if allowedip.Cidr == 0 { + for _, b := range allowedip.IP { + if b != 0 { + continue nextallowedip } - restrictAll = true - break } + return true } } - if restrictAll && len(conf.Interface.DNS) == 0 { + return false +} + +func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { + if len(conf.Interface.DNS) == 0 { log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") } - return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) + return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS) } diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go index 7da391ca..54645d24 100644 --- a/tunnel/firewall/blocker.go +++ b/tunnel/firewall/blocker.go @@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) { return bo, nil } -func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool) error { +func EnableFirewall(luid uint64, restrictToDNSServers []net.IP) error { if wfpSession != 0 { return errors.New("The firewall has already been enabled") } @@ -129,11 +129,9 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool } } - if restrictAll { - err = permitLoopback(session, baseObjects, 13) - if err != nil { - return wrapErr(err) - } + err = permitLoopback(session, baseObjects, 13) + if err != nil { + return wrapErr(err) } err = permitTunInterface(session, baseObjects, 12, luid) @@ -141,36 +139,32 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool return wrapErr(err) } - if restrictAll { - err = permitDHCPIPv4(session, baseObjects, 12) - if err != nil { - return wrapErr(err) - } + err = permitDHCPIPv4(session, baseObjects, 12) + if err != nil { + return wrapErr(err) + } - err = permitDHCPIPv6(session, baseObjects, 12) - if err != nil { - return wrapErr(err) - } + err = permitDHCPIPv6(session, baseObjects, 12) + if err != nil { + return wrapErr(err) + } - err = permitNdp(session, baseObjects, 12) - if err != nil { - return wrapErr(err) - } + err = permitNdp(session, baseObjects, 12) + if err != nil { + return wrapErr(err) + } - /* TODO: actually evaluate if this does anything and if we need this. It's layer 2; our other rules are layer 3. - * In other words, if somebody complains, try enabling it. For now, keep it off. - err = permitHyperV(session, baseObjects, 12) - if err != nil { - return wrapErr(err) - } - */ + /* TODO: actually evaluate if this does anything and if we need this. It's layer 2; our other rules are layer 3. + * In other words, if somebody complains, try enabling it. For now, keep it off. + err = permitHyperV(session, baseObjects, 12) + if err != nil { + return wrapErr(err) } + */ - if restrictAll { - err = blockAll(session, baseObjects, 0) - if err != nil { - return wrapErr(err) - } + err = blockAll(session, baseObjects, 0) + if err != nil { + return wrapErr(err) } return nil diff --git a/tunnel/service.go b/tunnel/service.go index b3699dd5..a9798103 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -164,11 +164,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } nativeTun = wintun.(*tun.NativeTun) - log.Println("Enabling firewall rules") - err = enableFirewall(conf, nativeTun) - if err != nil { - serviceError = services.ErrorFirewall - return + if shouldEnableFirewall(conf) { + log.Println("Enabling firewall rules (\"kill-switch\")") + err = enableFirewall(conf, nativeTun) + if err != nil { + serviceError = services.ErrorFirewall + return + } } log.Println("Dropping privileges") -- cgit v1.2.3-59-g8ed1b