summaryrefslogtreecommitdiffstatshomepage
path: root/src/receive.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/receive.c')
-rw-r--r--src/receive.c91
1 files changed, 49 insertions, 42 deletions
diff --git a/src/receive.c b/src/receive.c
index 5707ab2..f791a2e 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -30,9 +30,11 @@ static inline void update_latest_addr(struct wireguard_peer *peer, struct sk_buf
socket_set_peer_endpoint(peer, &endpoint);
}
-static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size_t *data_len)
+static inline int skb_prepare_header(struct sk_buff *skb)
{
struct udphdr *udp;
+ size_t data_offset, data_len;
+ enum message_type message_type;
if (unlikely(skb->len < sizeof(struct iphdr)))
return -EINVAL;
@@ -42,35 +44,50 @@ static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size
return -EINVAL;
udp = udp_hdr(skb);
- *data_offset = (u8 *)udp - skb->data;
- if (unlikely(*data_offset > U16_MAX)) {
+ data_offset = (u8 *)udp - skb->data;
+ if (unlikely(data_offset > U16_MAX)) {
net_dbg_skb_ratelimited("Packet has offset at impossible location from %pISpfsc\n", skb);
return -EINVAL;
}
- if (unlikely(*data_offset + sizeof(struct udphdr) > skb->len)) {
+ if (unlikely(data_offset + sizeof(struct udphdr) > skb->len)) {
net_dbg_skb_ratelimited("Packet isn't big enough to have UDP fields from %pISpfsc\n", skb);
return -EINVAL;
}
- *data_len = ntohs(udp->len);
- if (unlikely(*data_len < sizeof(struct udphdr))) {
+ data_len = ntohs(udp->len);
+ if (unlikely(data_len < sizeof(struct udphdr))) {
net_dbg_skb_ratelimited("UDP packet is reporting too small of a size from %pISpfsc\n", skb);
return -EINVAL;
}
- if (unlikely(*data_len > skb->len - *data_offset)) {
+ if (unlikely(data_len > skb->len - data_offset)) {
net_dbg_skb_ratelimited("UDP packet is lying about its size from %pISpfsc\n", skb);
return -EINVAL;
}
- *data_len -= sizeof(struct udphdr);
- *data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
- if (!pskb_may_pull(skb, *data_offset + sizeof(struct message_header))) {
+ data_len -= sizeof(struct udphdr);
+ data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
+ if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)))) {
net_dbg_skb_ratelimited("Could not pull header into data section from %pISpfsc\n", skb);
return -EINVAL;
}
-
- return 0;
+ if (pskb_trim(skb, data_len + data_offset) < 0) {
+ net_dbg_skb_ratelimited("Could not trim packet from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ skb_pull(skb, data_offset);
+ if (unlikely(skb->len != data_len)) {
+ net_dbg_skb_ratelimited("Final len does not agree with calculated len from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ message_type = message_determine_type(skb);
+ __skb_push(skb, data_offset);
+ if (unlikely(!pskb_may_pull(skb, data_offset + message_header_sizes[message_type]))) {
+ net_dbg_skb_ratelimited("Could not pull full header into data section from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ __skb_pull(skb, data_offset);
+ return message_type;
}
-static void receive_handshake_packet(struct wireguard_device *wg, void *data, size_t len, struct sk_buff *skb)
+static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb)
{
struct wireguard_peer *peer = NULL;
enum message_type message_type;
@@ -78,16 +95,16 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
enum cookie_mac_state mac_state;
bool packet_needs_cookie;
- message_type = message_determine_type(data, len);
+ message_type = message_determine_type(skb);
if (message_type == MESSAGE_HANDSHAKE_COOKIE) {
net_dbg_skb_ratelimited("Receiving cookie response from %pISpfsc\n", skb);
- cookie_message_consume(data, wg);
+ cookie_message_consume((struct message_handshake_cookie *)skb->data, wg);
return;
}
under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 2;
- mac_state = cookie_validate_packet(&wg->cookie_checker, skb, data, len, under_load);
+ mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load);
if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
packet_needs_cookie = false;
else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
@@ -98,13 +115,13 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
}
switch (message_type) {
- case MESSAGE_HANDSHAKE_INITIATION:
+ case MESSAGE_HANDSHAKE_INITIATION: {
+ struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data;
if (packet_needs_cookie) {
- struct message_handshake_initiation *message = data;
- packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index);
+ packet_send_handshake_cookie(wg, skb, message->sender_index);
return;
}
- peer = noise_handshake_consume_initiation(data, wg);
+ peer = noise_handshake_consume_initiation(message, wg);
if (unlikely(!peer)) {
net_dbg_skb_ratelimited("Invalid handshake initiation from %pISpfsc\n", skb);
return;
@@ -113,13 +130,14 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
net_dbg_ratelimited("Receiving handshake initiation from peer %Lu (%pISpfsc)\n", peer->internal_id, &peer->endpoint.addr);
packet_send_handshake_response(peer);
break;
- case MESSAGE_HANDSHAKE_RESPONSE:
+ }
+ case MESSAGE_HANDSHAKE_RESPONSE: {
+ struct message_handshake_response *message = (struct message_handshake_response *)skb->data;
if (packet_needs_cookie) {
- struct message_handshake_response *message = data;
- packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index);
+ packet_send_handshake_cookie(wg, skb, message->sender_index);
return;
}
- peer = noise_handshake_consume_response(data, wg);
+ peer = noise_handshake_consume_response(message, wg);
if (unlikely(!peer)) {
net_dbg_skb_ratelimited("Invalid handshake response from %pISpfsc\n", skb);
return;
@@ -137,6 +155,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
packet_send_keepalive(peer);
}
break;
+ }
default:
WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
return;
@@ -144,7 +163,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
BUG_ON(!peer);
- rx_stats(peer, len);
+ rx_stats(peer, skb->len);
timers_any_authenticated_packet_received(peer);
timers_any_authenticated_packet_traversal(peer);
peer_put(peer);
@@ -154,12 +173,10 @@ void packet_process_queued_handshake_packets(struct work_struct *work)
{
struct wireguard_device *wg = container_of(work, struct wireguard_device, incoming_handshakes_work);
struct sk_buff *skb;
- size_t len, offset;
size_t num_processed = 0;
while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
- if (!skb_data_offset(skb, &offset, &len))
- receive_handshake_packet(wg, skb->data + offset, len, skb);
+ receive_handshake_packet(wg, skb);
dev_kfree_skb(skb);
if (++num_processed == MAX_BURST_INCOMING_HANDSHAKES) {
queue_work(wg->workqueue, &wg->incoming_handshakes_work);
@@ -188,11 +205,6 @@ static void keep_key_fresh(struct wireguard_peer *peer)
}
}
-struct packet_cb {
- u8 ds;
-};
-#define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
-
static void receive_data_packet(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key, int err)
{
struct net_device *dev;
@@ -276,11 +288,10 @@ continue_processing:
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
{
- size_t len, offset;
-
- if (unlikely(skb_data_offset(skb, &offset, &len) < 0))
+ int message_type = skb_prepare_header(skb);
+ if (unlikely(message_type < 0))
goto err;
- switch (message_determine_type(skb->data + offset, len)) {
+ switch (message_type) {
case MESSAGE_HANDSHAKE_INITIATION:
case MESSAGE_HANDSHAKE_RESPONSE:
case MESSAGE_HANDSHAKE_COOKIE:
@@ -288,17 +299,13 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
net_dbg_skb_ratelimited("Too many handshakes queued, dropping packet from %pISpfsc\n", skb);
goto err;
}
- if (skb_linearize(skb) < 0) {
- net_dbg_skb_ratelimited("Unable to linearize handshake skb from %pISpfsc\n", skb);
- goto err;
- }
skb_queue_tail(&wg->incoming_handshakes, skb);
/* Queues up a call to packet_process_queued_handshake_packets(skb): */
queue_work(wg->workqueue, &wg->incoming_handshakes_work);
break;
case MESSAGE_DATA:
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
- packet_consume_data(skb, offset, wg, receive_data_packet);
+ packet_consume_data(skb, wg, receive_data_packet);
break;
default:
net_dbg_skb_ratelimited("Invalid packet from %pISpfsc\n", skb);