aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel/firewall/blocker.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/firewall/blocker.go')
-rw-r--r--tunnel/firewall/blocker.go30
1 files changed, 13 insertions, 17 deletions
diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go
index 7da391ca..8a4967ba 100644
--- a/tunnel/firewall/blocker.go
+++ b/tunnel/firewall/blocker.go
@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
import (
"errors"
- "net"
+ "net/netip"
"unsafe"
"golang.org/x/sys/windows"
@@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) {
return bo, nil
}
-func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool) error {
+func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error {
if wfpSession != 0 {
return errors.New("The firewall has already been enabled")
}
@@ -122,26 +122,24 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool
return wrapErr(err)
}
- if len(restrictToDNSServers) > 0 {
- err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14)
- if err != nil {
- return wrapErr(err)
+ if !doNotRestrict {
+ if len(restrictToDNSServers) > 0 {
+ err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14)
+ if err != nil {
+ return wrapErr(err)
+ }
}
- }
- if restrictAll {
err = permitLoopback(session, baseObjects, 13)
if err != nil {
return wrapErr(err)
}
- }
- err = permitTunInterface(session, baseObjects, 12, luid)
- if err != nil {
- return wrapErr(err)
- }
+ err = permitTunInterface(session, baseObjects, 12, luid)
+ if err != nil {
+ return wrapErr(err)
+ }
- if restrictAll {
err = permitDHCPIPv4(session, baseObjects, 12)
if err != nil {
return wrapErr(err)
@@ -164,9 +162,7 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool
return wrapErr(err)
}
*/
- }
- if restrictAll {
err = blockAll(session, baseObjects, 0)
if err != nil {
return wrapErr(err)