aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTyler Kropp <kropptyler@gmail.com>2020-03-02 19:41:28 -0500
committerDavid Crawshaw <david@zentus.com>2020-03-31 09:32:57 +1100
commitc7bb15a70df5cfc949c836429b5e39ce57d047f9 (patch)
tree09d98c2464f5af1bcf58d41ef7f1fd79f7d3d769
parentwgcfg: new config package (diff)
downloadwireguard-go-c7bb15a70df5cfc949c836429b5e39ce57d047f9.tar.xz
wireguard-go-c7bb15a70df5cfc949c836429b5e39ce57d047f9.zip
wgcfg: add fast CIDR.Contains implementation
Signed-off-by: Tyler Kropp <kropptyler@gmail.com>
-rw-r--r--wgcfg/ip.go26
-rw-r--r--wgcfg/ip_test.go118
2 files changed, 142 insertions, 2 deletions
diff --git a/wgcfg/ip.go b/wgcfg/ip.go
index ecf5faf..7541d18 100644
--- a/wgcfg/ip.go
+++ b/wgcfg/ip.go
@@ -2,6 +2,7 @@ package wgcfg
import (
"fmt"
+ "math"
"net"
)
@@ -106,12 +107,33 @@ func (r *CIDR) IPNet() *net.IPNet {
}
return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)}
}
+
func (r *CIDR) Contains(ip *IP) bool {
if r == nil || ip == nil {
return false
}
- // TODO: this isn't hard, write a more efficient implementation.
- return r.IPNet().Contains(ip.IP())
+ c := int8(r.Mask)
+ i := 0
+ if r.IP.Is4() {
+ i = 12
+ if ip.Is6() {
+ return false
+ }
+ }
+ for ; i < 16 && c > 0; i++ {
+ var x uint8
+ if c < 8 {
+ x = 8 - uint8(c)
+ }
+ m := uint8(math.MaxUint8) >> x << x
+ a := r.IP.Addr[i] & m
+ b := ip.Addr[i] & m
+ if a != b {
+ return false
+ }
+ c -= 8
+ }
+ return true
}
func (r CIDR) MarshalText() ([]byte, error) {
diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go
new file mode 100644
index 0000000..d3682bb
--- /dev/null
+++ b/wgcfg/ip_test.go
@@ -0,0 +1,118 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg_test
+
+import (
+ "testing"
+
+ "golang.zx2c4.com/wireguard/wgcfg"
+)
+
+func TestCIDRContains(t *testing.T) {
+ t.Run("home router test", func(t *testing.T) {
+ r, err := wgcfg.ParseCIDR("192.168.0.0/24")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("192.168.0.1")
+ if ip == nil {
+ t.Fatalf("address failed to parse")
+ }
+ if !r.Contains(ip) {
+ t.Fatalf("'%s' should contain '%s'", r, ip)
+ }
+ })
+
+ t.Run("IPv4 outside network", func(t *testing.T) {
+ r, err := wgcfg.ParseCIDR("192.168.0.0/30")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("192.168.0.4")
+ if ip == nil {
+ t.Fatalf("address failed to parse")
+ }
+ if r.Contains(ip) {
+ t.Fatalf("'%s' should not contain '%s'", r, ip)
+ }
+ })
+
+ t.Run("IPv4 does not contain IPv6", func(t *testing.T) {
+ r, err := wgcfg.ParseCIDR("192.168.0.0/24")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334")
+ if ip == nil {
+ t.Fatalf("address failed to parse")
+ }
+ if r.Contains(ip) {
+ t.Fatalf("'%s' should not contain '%s'", r, ip)
+ }
+ })
+
+ t.Run("IPv6 inside network", func(t *testing.T) {
+ r, err := wgcfg.ParseCIDR("2001:db8:1234::/48")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
+ if ip == nil {
+ t.Fatalf("ParseIP returned nil pointer")
+ }
+ if !r.Contains(ip) {
+ t.Fatalf("'%s' should not contain '%s'", r, ip)
+ }
+ })
+
+ t.Run("IPv6 outside network", func(t *testing.T) {
+ r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4")
+ if ip == nil {
+ t.Fatalf("ParseIP returned nil pointer")
+ }
+ if r.Contains(ip) {
+ t.Fatalf("'%s' should not contain '%s'", r, ip)
+ }
+ })
+}
+
+func BenchmarkCIDRContainsIPv4(b *testing.B) {
+ b.Run("IPv4", func(b *testing.B) {
+ r, err := wgcfg.ParseCIDR("192.168.1.0/24")
+ if err != nil {
+ b.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("1.2.3.4")
+ if ip == nil {
+ b.Fatalf("ParseIP returned nil pointer")
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ r.Contains(ip)
+ }
+ })
+
+ b.Run("IPv6", func(b *testing.B) {
+ r, err := wgcfg.ParseCIDR("2001:db8:1234::/48")
+ if err != nil {
+ b.Fatal(err)
+ }
+ ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
+ if ip == nil {
+ b.Fatalf("ParseIP returned nil pointer")
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ r.Contains(ip)
+ }
+ })
+}