aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--src/device.c2
-rw-r--r--src/queueing.h8
-rw-r--r--src/receive.c4
3 files changed, 10 insertions, 4 deletions
diff --git a/src/device.c b/src/device.c
index 73c892a..5ea039b 100644
--- a/src/device.c
+++ b/src/device.c
@@ -130,7 +130,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
u32 mtu;
int ret;
- if (unlikely(wg_skb_examine_untrusted_ip_hdr(skb) != skb->protocol)) {
+ if (unlikely(!wg_check_packet_protocol(skb))) {
ret = -EPROTONOSUPPORT;
net_dbg_ratelimited("%s: Invalid IP packet\n", dev->name);
goto err;
diff --git a/src/queueing.h b/src/queueing.h
index e49a464..256c6be 100644
--- a/src/queueing.h
+++ b/src/queueing.h
@@ -66,7 +66,7 @@ struct packet_cb {
#define PACKET_PEER(skb) (PACKET_CB(skb)->keypair->entry.peer)
/* Returns either the correct skb->protocol value, or 0 if invalid. */
-static inline __be16 wg_skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
+static inline __be16 wg_examine_packet_protocol(struct sk_buff *skb)
{
if (skb_network_header(skb) >= skb->head &&
(skb_network_header(skb) + sizeof(struct iphdr)) <=
@@ -81,6 +81,12 @@ static inline __be16 wg_skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
return 0;
}
+static inline bool wg_check_packet_protocol(struct sk_buff *skb)
+{
+ __be16 real_protocol = wg_examine_packet_protocol(skb);
+ return real_protocol && skb->protocol == real_protocol;
+}
+
static inline void wg_reset_packet(struct sk_buff *skb)
{
const int pfmemalloc = skb->pfmemalloc;
diff --git a/src/receive.c b/src/receive.c
index a94fcd7..dde4109 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -57,7 +57,7 @@ static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
size_t data_offset, data_len, header_len;
struct udphdr *udp;
- if (unlikely(wg_skb_examine_untrusted_ip_hdr(skb) != skb->protocol ||
+ if (unlikely(!wg_check_packet_protocol(skb) ||
skb_transport_header(skb) < skb->head ||
(skb_transport_header(skb) + sizeof(struct udphdr)) >
skb_tail_pointer(skb)))
@@ -392,7 +392,7 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
#ifndef COMPAT_CANNOT_USE_CSUM_LEVEL
skb->csum_level = ~0; /* All levels */
#endif
- skb->protocol = wg_skb_examine_untrusted_ip_hdr(skb);
+ skb->protocol = wg_examine_packet_protocol(skb);
if (skb->protocol == htons(ETH_P_IP)) {
len = ntohs(ip_hdr(skb)->tot_len);
if (unlikely(len < sizeof(struct iphdr)))