aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/allowedips.c9
-rw-r--r--drivers/net/wireguard/device.c44
-rw-r--r--drivers/net/wireguard/netlink.c14
-rw-r--r--drivers/net/wireguard/noise.c45
-rw-r--r--drivers/net/wireguard/peer.c3
-rw-r--r--drivers/net/wireguard/queueing.c3
-rw-r--r--drivers/net/wireguard/queueing.h4
-rw-r--r--drivers/net/wireguard/receive.c9
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c22
-rw-r--r--drivers/net/wireguard/selftest/ratelimiter.c4
-rw-r--r--drivers/net/wireguard/socket.c5
11 files changed, 103 insertions, 59 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 9a4c8ff32d9d..5bf7822c53f1 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -6,6 +6,8 @@
#include "allowedips.h"
#include "peer.h"
+enum { MAX_ALLOWEDIPS_BITS = 128 };
+
static struct kmem_cache *node_cache;
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
@@ -40,7 +42,8 @@ static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len)
{
if (rcu_access_pointer(p)) {
- WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
+ if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_BITS))
+ return;
stack[(*len)++] = rcu_dereference_raw(p);
}
}
@@ -52,7 +55,7 @@ static void node_free_rcu(struct rcu_head *rcu)
static void root_free_rcu(struct rcu_head *rcu)
{
- struct allowedips_node *node, *stack[128] = {
+ struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = {
container_of(rcu, struct allowedips_node, rcu) };
unsigned int len = 1;
@@ -65,7 +68,7 @@ static void root_free_rcu(struct rcu_head *rcu)
static void root_remove_peer_lists(struct allowedips_node *root)
{
- struct allowedips_node *node, *stack[128] = { root };
+ struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = { root };
unsigned int len = 1;
while (len > 0 && (node = stack[--len])) {
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index a46067c38bf5..d58e9f818d3b 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -19,6 +19,7 @@
#include <linux/if_arp.h>
#include <linux/icmp.h>
#include <linux/suspend.h>
+#include <net/dst_metadata.h>
#include <net/icmp.h>
#include <net/rtnetlink.h>
#include <net/ip_tunnels.h>
@@ -59,9 +60,7 @@ out:
return ret;
}
-#ifdef CONFIG_PM_SLEEP
-static int wg_pm_notification(struct notifier_block *nb, unsigned long action,
- void *data)
+static int wg_pm_notification(struct notifier_block *nb, unsigned long action, void *data)
{
struct wg_device *wg;
struct wg_peer *peer;
@@ -70,7 +69,8 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action,
* its normal operation rather than as a somewhat rare event, then we
* don't actually want to clear keys.
*/
- if (IS_ENABLED(CONFIG_PM_AUTOSLEEP) || IS_ENABLED(CONFIG_ANDROID))
+ if (IS_ENABLED(CONFIG_PM_AUTOSLEEP) ||
+ IS_ENABLED(CONFIG_PM_USERSPACE_AUTOSLEEP))
return 0;
if (action != PM_HIBERNATION_PREPARE && action != PM_SUSPEND_PREPARE)
@@ -92,7 +92,24 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action,
}
static struct notifier_block pm_notifier = { .notifier_call = wg_pm_notification };
-#endif
+
+static int wg_vm_notification(struct notifier_block *nb, unsigned long action, void *data)
+{
+ struct wg_device *wg;
+ struct wg_peer *peer;
+
+ rtnl_lock();
+ list_for_each_entry(wg, &device_list, device_list) {
+ mutex_lock(&wg->device_update_lock);
+ list_for_each_entry(peer, &wg->peer_list, peer_list)
+ wg_noise_expire_current_peer_keypairs(peer);
+ mutex_unlock(&wg->device_update_lock);
+ }
+ rtnl_unlock();
+ return 0;
+}
+
+static struct notifier_block vm_notifier = { .notifier_call = wg_vm_notification };
static int wg_stop(struct net_device *dev)
{
@@ -152,7 +169,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
goto err_peer;
}
- mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu;
+ mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu;
__skb_queue_head_init(&packets);
if (!skb_is_gso(skb)) {
@@ -424,16 +441,18 @@ int __init wg_device_init(void)
{
int ret;
-#ifdef CONFIG_PM_SLEEP
ret = register_pm_notifier(&pm_notifier);
if (ret)
return ret;
-#endif
- ret = register_pernet_device(&pernet_ops);
+ ret = register_random_vmfork_notifier(&vm_notifier);
if (ret)
goto error_pm;
+ ret = register_pernet_device(&pernet_ops);
+ if (ret)
+ goto error_vm;
+
ret = rtnl_link_register(&link_ops);
if (ret)
goto error_pernet;
@@ -442,10 +461,10 @@ int __init wg_device_init(void)
error_pernet:
unregister_pernet_device(&pernet_ops);
+error_vm:
+ unregister_random_vmfork_notifier(&vm_notifier);
error_pm:
-#ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier);
-#endif
return ret;
}
@@ -453,8 +472,7 @@ void wg_device_uninit(void)
{
rtnl_link_unregister(&link_ops);
unregister_pernet_device(&pernet_ops);
-#ifdef CONFIG_PM_SLEEP
+ unregister_random_vmfork_notifier(&vm_notifier);
unregister_pm_notifier(&pm_notifier);
-#endif
rcu_barrier();
}
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index d0f3b6d7f408..43c8c84e7ea8 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -436,14 +436,13 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
if (attrs[WGPEER_A_ENDPOINT]) {
struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
+ struct endpoint endpoint = { { { 0 } } };
- if ((len == sizeof(struct sockaddr_in) &&
- addr->sa_family == AF_INET) ||
- (len == sizeof(struct sockaddr_in6) &&
- addr->sa_family == AF_INET6)) {
- struct endpoint endpoint = { { { 0 } } };
-
- memcpy(&endpoint.addr, addr, len);
+ if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) {
+ endpoint.addr4 = *(struct sockaddr_in *)addr;
+ wg_socket_set_peer_endpoint(peer, &endpoint);
+ } else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) {
+ endpoint.addr6 = *(struct sockaddr_in6 *)addr;
wg_socket_set_peer_endpoint(peer, &endpoint);
}
}
@@ -621,6 +620,7 @@ static const struct genl_ops genl_ops[] = {
static struct genl_family genl_family __ro_after_init = {
.ops = genl_ops,
.n_ops = ARRAY_SIZE(genl_ops),
+ .resv_start_op = WG_CMD_SET_DEVICE + 1,
.name = WG_GENL_NAME,
.version = WG_GENL_VERSION,
.maxattr = WGDEVICE_A_MAX,
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index c0cfd9b36c0b..720952b92e78 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -302,6 +302,41 @@ void wg_noise_set_static_identity_private_key(
static_identity->static_public, private_key);
}
+static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
+{
+ struct blake2s_state state;
+ u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
+ u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
+ int i;
+
+ if (keylen > BLAKE2S_BLOCK_SIZE) {
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, key, keylen);
+ blake2s_final(&state, x_key);
+ } else
+ memcpy(x_key, key, keylen);
+
+ for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
+ x_key[i] ^= 0x36;
+
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&state, in, inlen);
+ blake2s_final(&state, i_hash);
+
+ for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
+ x_key[i] ^= 0x5c ^ 0x36;
+
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
+ blake2s_final(&state, i_hash);
+
+ memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
+ memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
+ memzero_explicit(i_hash, BLAKE2S_HASH_SIZE);
+}
+
/* This is Hugo Krawczyk's HKDF:
* - https://eprint.iacr.org/2010/264.pdf
* - https://tools.ietf.org/html/rfc5869
@@ -322,14 +357,14 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
((third_len || third_dst) && (!second_len || !second_dst))));
/* Extract entropy from data into secret */
- blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
+ hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
if (!first_dst || !first_len)
goto out;
/* Expand first key: key = secret, data = 0x1 */
output[0] = 1;
- blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
+ hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
memcpy(first_dst, output, first_len);
if (!second_dst || !second_len)
@@ -337,8 +372,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
/* Expand second key: key = secret, data = first-key || 0x2 */
output[BLAKE2S_HASH_SIZE] = 2;
- blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
- BLAKE2S_HASH_SIZE);
+ hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
memcpy(second_dst, output, second_len);
if (!third_dst || !third_len)
@@ -346,8 +380,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
/* Expand third key: key = secret, data = second-key || 0x3 */
output[BLAKE2S_HASH_SIZE] = 3;
- blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
- BLAKE2S_HASH_SIZE);
+ hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
memcpy(third_dst, output, third_len);
out:
diff --git a/drivers/net/wireguard/peer.c b/drivers/net/wireguard/peer.c
index 1acd00ab2fbc..1cb502a932e0 100644
--- a/drivers/net/wireguard/peer.c
+++ b/drivers/net/wireguard/peer.c
@@ -54,8 +54,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
skb_queue_head_init(&peer->staged_packet_queue);
wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state);
- netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll,
- NAPI_POLL_WEIGHT);
+ netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll);
napi_enable(&peer->napi);
list_add_tail(&peer->peer_list, &wg->peer_list);
INIT_LIST_HEAD(&peer->allowedips_list);
diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c
index 1de413b19e34..8084e7408c0a 100644
--- a/drivers/net/wireguard/queueing.c
+++ b/drivers/net/wireguard/queueing.c
@@ -4,6 +4,7 @@
*/
#include "queueing.h"
+#include <linux/skb_array.h>
struct multicore_worker __percpu *
wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr)
@@ -42,7 +43,7 @@ void wg_packet_queue_free(struct crypt_queue *queue, bool purge)
{
free_percpu(queue->worker);
WARN_ON(!purge && !__ptr_ring_empty(&queue->ring));
- ptr_ring_cleanup(&queue->ring, purge ? (void(*)(void*))kfree_skb : NULL);
+ ptr_ring_cleanup(&queue->ring, purge ? __skb_array_destroy_skb : NULL);
}
#define NEXT(skb) ((skb)->prev)
diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h
index e2388107f7fd..583adb37ee1e 100644
--- a/drivers/net/wireguard/queueing.h
+++ b/drivers/net/wireguard/queueing.h
@@ -79,9 +79,7 @@ static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
u8 sw_hash = skb->sw_hash;
u32 hash = skb->hash;
skb_scrub_packet(skb, true);
- memset(&skb->headers_start, 0,
- offsetof(struct sk_buff, headers_end) -
- offsetof(struct sk_buff, headers_start));
+ memset(&skb->headers, 0, sizeof(skb->headers));
if (encapsulating) {
skb->l4_hash = l4_hash;
skb->sw_hash = sw_hash;
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index 7b8df406c773..7135d51d2d87 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -19,15 +19,8 @@
/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
- struct pcpu_sw_netstats *tstats =
- get_cpu_ptr(peer->device->dev->tstats);
-
- u64_stats_update_begin(&tstats->syncp);
- ++tstats->rx_packets;
- tstats->rx_bytes += len;
+ dev_sw_netstats_rx_add(peer->device->dev, len);
peer->rx_bytes += len;
- u64_stats_update_end(&tstats->syncp);
- put_cpu_ptr(tstats);
}
#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
index e173204ae7d7..19eac00b2381 100644
--- a/drivers/net/wireguard/selftest/allowedips.c
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -284,7 +284,7 @@ static __init bool randomized_test(void)
mutex_lock(&mutex);
for (i = 0; i < NUM_RAND_ROUTES; ++i) {
- prandom_bytes(ip, 4);
+ get_random_bytes(ip, 4);
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
@@ -299,7 +299,7 @@ static __init bool randomized_test(void)
}
for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
memcpy(mutated, ip, 4);
- prandom_bytes(mutate_mask, 4);
+ get_random_bytes(mutate_mask, 4);
mutate_amount = prandom_u32_max(32);
for (k = 0; k < mutate_amount / 8; ++k)
mutate_mask[k] = 0xff;
@@ -310,7 +310,7 @@ static __init bool randomized_test(void)
for (k = 0; k < 4; ++k)
mutated[k] = (mutated[k] & mutate_mask[k]) |
(~mutate_mask[k] &
- prandom_u32_max(256));
+ get_random_u8());
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
if (wg_allowedips_insert_v4(&t,
@@ -328,7 +328,7 @@ static __init bool randomized_test(void)
}
for (i = 0; i < NUM_RAND_ROUTES; ++i) {
- prandom_bytes(ip, 16);
+ get_random_bytes(ip, 16);
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
@@ -343,7 +343,7 @@ static __init bool randomized_test(void)
}
for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
memcpy(mutated, ip, 16);
- prandom_bytes(mutate_mask, 16);
+ get_random_bytes(mutate_mask, 16);
mutate_amount = prandom_u32_max(128);
for (k = 0; k < mutate_amount / 8; ++k)
mutate_mask[k] = 0xff;
@@ -354,7 +354,7 @@ static __init bool randomized_test(void)
for (k = 0; k < 4; ++k)
mutated[k] = (mutated[k] & mutate_mask[k]) |
(~mutate_mask[k] &
- prandom_u32_max(256));
+ get_random_u8());
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
if (wg_allowedips_insert_v6(&t,
@@ -381,13 +381,13 @@ static __init bool randomized_test(void)
for (j = 0;; ++j) {
for (i = 0; i < NUM_QUERIES; ++i) {
- prandom_bytes(ip, 4);
+ get_random_bytes(ip, 4);
if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
pr_err("allowedips random v4 self-test: FAIL\n");
goto free;
}
- prandom_bytes(ip, 16);
+ get_random_bytes(ip, 16);
if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
pr_err("allowedips random v6 self-test: FAIL\n");
goto free;
@@ -593,10 +593,10 @@ bool __init wg_allowedips_selftest(void)
wg_allowedips_remove_by_peer(&t, a, &mutex);
test_negative(4, a, 192, 168, 0, 1);
- /* These will hit the WARN_ON(len >= 128) in free_node if something
- * goes wrong.
+ /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_BITS) in free_node
+ * if something goes wrong.
*/
- for (i = 0; i < 128; ++i) {
+ for (i = 0; i < MAX_ALLOWEDIPS_BITS; ++i) {
part = cpu_to_be64(~(1LLU << (i % 64)));
memset(&ip, 0xff, 16);
memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
diff --git a/drivers/net/wireguard/selftest/ratelimiter.c b/drivers/net/wireguard/selftest/ratelimiter.c
index 007cd4457c5f..d4bb40a695ab 100644
--- a/drivers/net/wireguard/selftest/ratelimiter.c
+++ b/drivers/net/wireguard/selftest/ratelimiter.c
@@ -167,7 +167,7 @@ bool __init wg_ratelimiter_selftest(void)
++test;
#endif
- for (trials = TRIALS_BEFORE_GIVING_UP;;) {
+ for (trials = TRIALS_BEFORE_GIVING_UP; IS_ENABLED(DEBUG_RATELIMITER_TIMINGS);) {
int test_count = 0, ret;
ret = timings_test(skb4, hdr4, skb6, hdr6, &test_count);
@@ -176,7 +176,6 @@ bool __init wg_ratelimiter_selftest(void)
test += test_count;
goto err;
}
- msleep(500);
continue;
} else if (ret < 0) {
test += test_count;
@@ -195,7 +194,6 @@ bool __init wg_ratelimiter_selftest(void)
test += test_count;
goto err;
}
- msleep(50);
continue;
}
test += test_count;
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index 6f07b949cb81..0414d7a6ce74 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -160,6 +160,7 @@ out:
rcu_read_unlock_bh();
return ret;
#else
+ kfree_skb(skb);
return -EAFNOSUPPORT;
#endif
}
@@ -241,7 +242,7 @@ int wg_socket_endpoint_from_skb(struct endpoint *endpoint,
endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
endpoint->src4.s_addr = ip_hdr(skb)->daddr;
endpoint->src_if4 = skb->skb_iif;
- } else if (skb->protocol == htons(ETH_P_IPV6)) {
+ } else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) {
endpoint->addr6.sin6_family = AF_INET6;
endpoint->addr6.sin6_port = udp_hdr(skb)->source;
endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
@@ -284,7 +285,7 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer,
peer->endpoint.addr4 = endpoint->addr4;
peer->endpoint.src4 = endpoint->src4;
peer->endpoint.src_if4 = endpoint->src_if4;
- } else if (endpoint->addr.sa_family == AF_INET6) {
+ } else if (IS_ENABLED(CONFIG_IPV6) && endpoint->addr.sa_family == AF_INET6) {
peer->endpoint.addr6 = endpoint->addr6;
peer->endpoint.src6 = endpoint->src6;
} else {