aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go77
1 files changed, 61 insertions, 16 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index f0e954b..058aca5 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -17,6 +17,7 @@ import (
"regexp"
"strconv"
"strings"
+ "sync"
"time"
"golang.zx2c4.com/go118/netip"
@@ -285,11 +286,13 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
}
type PingConn struct {
- laddr PingAddr
- raddr PingAddr
- wq waiter.Queue
- ep tcpip.Endpoint
- deadline time.Time
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ mu sync.RWMutex
+ deadline time.Time
+ deadlineBreaker chan struct{}
}
type PingAddr struct{ addr netip.Addr }
@@ -307,6 +310,20 @@ func (ia PingAddr) Network() string {
return "ping"
}
+func PingAddrFromAddr(addr net.Addr) (PingAddr, error) {
+ switch v := addr.(type) {
+ case PingAddr:
+ return v, nil
+
+ case *net.IPAddr:
+ nip := netip.AddrFromSlice(v.IP)
+ return PingAddr{nip}, nil
+
+ default:
+ return PingAddr{}, errors.New("wrong address format")
+ }
+}
+
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
v6 := laddr.Is6() || raddr.Is6()
bind := laddr.IsValid()
@@ -325,7 +342,10 @@ func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
pn = ipv6.ProtocolNumber
}
- pc := &PingConn{laddr: PingAddr{laddr}}
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadlineBreaker: make(chan struct{}, 1),
+ }
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
if tcpipErr != nil {
@@ -360,6 +380,7 @@ func (pc *PingConn) RemoteAddr() net.Addr {
}
func (pc *PingConn) Close() error {
+ close(pc.deadlineBreaker)
pc.ep.Close()
return nil
}
@@ -369,8 +390,11 @@ func (pc *PingConn) SetWriteDeadline(t time.Time) error {
}
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
- ia, ok := addr.(PingAddr)
- if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
+ ia, err := PingAddrFromAddr(addr)
+ if err != nil {
+ return 0, fmt.Errorf("ping write: %w", err)
+ }
+ if !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
return 0, fmt.Errorf("ping write: mismatched protocols")
}
@@ -409,15 +433,32 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
pc.wq.EventRegister(&e, waiter.EventIn)
defer pc.wq.EventUnregister(&e)
- deadline := pc.deadline
+ ready := false
- if deadline.IsZero() {
- <-notifyCh
- } else {
- select {
- case <-time.NewTimer(deadline.Sub(time.Now())).C:
- return 0, nil, os.ErrDeadlineExceeded
- case <-notifyCh:
+ for !ready {
+ pc.mu.RLock()
+ deadlineBreaker := pc.deadlineBreaker
+ deadline := pc.deadline
+ pc.mu.RUnlock()
+
+ if deadline.IsZero() {
+ select {
+ case <-deadlineBreaker:
+ case <-notifyCh:
+ ready = true
+ }
+ } else {
+ t := time.NewTimer(deadline.Sub(time.Now()))
+ defer t.Stop()
+
+ select {
+ case <-t.C:
+ return 0, nil, os.ErrDeadlineExceeded
+
+ case <-deadlineBreaker:
+ case <-notifyCh:
+ ready = true
+ }
}
}
@@ -452,6 +493,10 @@ func (pc *PingConn) SetDeadline(t time.Time) error {
}
func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.mu.Lock()
+ defer pc.mu.Unlock()
+ close(pc.deadlineBreaker)
+ pc.deadlineBreaker = make(chan struct{}, 1)
pc.deadline = t
return nil
}