diff options
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r-- | drivers/net/wireguard/allowedips.c | 9 | ||||
-rw-r--r-- | drivers/net/wireguard/device.c | 44 | ||||
-rw-r--r-- | drivers/net/wireguard/netlink.c | 14 | ||||
-rw-r--r-- | drivers/net/wireguard/noise.c | 45 | ||||
-rw-r--r-- | drivers/net/wireguard/peer.c | 3 | ||||
-rw-r--r-- | drivers/net/wireguard/queueing.c | 3 | ||||
-rw-r--r-- | drivers/net/wireguard/queueing.h | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/receive.c | 9 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/allowedips.c | 22 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/ratelimiter.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/socket.c | 5 |
11 files changed, 103 insertions, 59 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 9a4c8ff32d9d..5bf7822c53f1 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -6,6 +6,8 @@ #include "allowedips.h" #include "peer.h" +enum { MAX_ALLOWEDIPS_BITS = 128 }; + static struct kmem_cache *node_cache; static void swap_endian(u8 *dst, const u8 *src, u8 bits) @@ -40,7 +42,8 @@ static void push_rcu(struct allowedips_node **stack, struct allowedips_node __rcu *p, unsigned int *len) { if (rcu_access_pointer(p)) { - WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); + if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_BITS)) + return; stack[(*len)++] = rcu_dereference_raw(p); } } @@ -52,7 +55,7 @@ static void node_free_rcu(struct rcu_head *rcu) static void root_free_rcu(struct rcu_head *rcu) { - struct allowedips_node *node, *stack[128] = { + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; @@ -65,7 +68,7 @@ static void root_free_rcu(struct rcu_head *rcu) static void root_remove_peer_lists(struct allowedips_node *root) { - struct allowedips_node *node, *stack[128] = { root }; + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = { root }; unsigned int len = 1; while (len > 0 && (node = stack[--len])) { diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index a46067c38bf5..d58e9f818d3b 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -19,6 +19,7 @@ #include <linux/if_arp.h> #include <linux/icmp.h> #include <linux/suspend.h> +#include <net/dst_metadata.h> #include <net/icmp.h> #include <net/rtnetlink.h> #include <net/ip_tunnels.h> @@ -59,9 +60,7 @@ out: return ret; } -#ifdef CONFIG_PM_SLEEP -static int wg_pm_notification(struct notifier_block *nb, unsigned long action, - void *data) +static int wg_pm_notification(struct notifier_block *nb, unsigned long action, void *data) { struct wg_device *wg; struct wg_peer *peer; @@ -70,7 +69,8 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, * its normal operation rather than as a somewhat rare event, then we * don't actually want to clear keys. */ - if (IS_ENABLED(CONFIG_PM_AUTOSLEEP) || IS_ENABLED(CONFIG_ANDROID)) + if (IS_ENABLED(CONFIG_PM_AUTOSLEEP) || + IS_ENABLED(CONFIG_PM_USERSPACE_AUTOSLEEP)) return 0; if (action != PM_HIBERNATION_PREPARE && action != PM_SUSPEND_PREPARE) @@ -92,7 +92,24 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, } static struct notifier_block pm_notifier = { .notifier_call = wg_pm_notification }; -#endif + +static int wg_vm_notification(struct notifier_block *nb, unsigned long action, void *data) +{ + struct wg_device *wg; + struct wg_peer *peer; + + rtnl_lock(); + list_for_each_entry(wg, &device_list, device_list) { + mutex_lock(&wg->device_update_lock); + list_for_each_entry(peer, &wg->peer_list, peer_list) + wg_noise_expire_current_peer_keypairs(peer); + mutex_unlock(&wg->device_update_lock); + } + rtnl_unlock(); + return 0; +} + +static struct notifier_block vm_notifier = { .notifier_call = wg_vm_notification }; static int wg_stop(struct net_device *dev) { @@ -152,7 +169,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev) goto err_peer; } - mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; + mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; __skb_queue_head_init(&packets); if (!skb_is_gso(skb)) { @@ -424,16 +441,18 @@ int __init wg_device_init(void) { int ret; -#ifdef CONFIG_PM_SLEEP ret = register_pm_notifier(&pm_notifier); if (ret) return ret; -#endif - ret = register_pernet_device(&pernet_ops); + ret = register_random_vmfork_notifier(&vm_notifier); if (ret) goto error_pm; + ret = register_pernet_device(&pernet_ops); + if (ret) + goto error_vm; + ret = rtnl_link_register(&link_ops); if (ret) goto error_pernet; @@ -442,10 +461,10 @@ int __init wg_device_init(void) error_pernet: unregister_pernet_device(&pernet_ops); +error_vm: + unregister_random_vmfork_notifier(&vm_notifier); error_pm: -#ifdef CONFIG_PM_SLEEP unregister_pm_notifier(&pm_notifier); -#endif return ret; } @@ -453,8 +472,7 @@ void wg_device_uninit(void) { rtnl_link_unregister(&link_ops); unregister_pernet_device(&pernet_ops); -#ifdef CONFIG_PM_SLEEP + unregister_random_vmfork_notifier(&vm_notifier); unregister_pm_notifier(&pm_notifier); -#endif rcu_barrier(); } diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index d0f3b6d7f408..43c8c84e7ea8 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -436,14 +436,13 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) if (attrs[WGPEER_A_ENDPOINT]) { struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]); size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]); + struct endpoint endpoint = { { { 0 } } }; - if ((len == sizeof(struct sockaddr_in) && - addr->sa_family == AF_INET) || - (len == sizeof(struct sockaddr_in6) && - addr->sa_family == AF_INET6)) { - struct endpoint endpoint = { { { 0 } } }; - - memcpy(&endpoint.addr, addr, len); + if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) { + endpoint.addr4 = *(struct sockaddr_in *)addr; + wg_socket_set_peer_endpoint(peer, &endpoint); + } else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) { + endpoint.addr6 = *(struct sockaddr_in6 *)addr; wg_socket_set_peer_endpoint(peer, &endpoint); } } @@ -621,6 +620,7 @@ static const struct genl_ops genl_ops[] = { static struct genl_family genl_family __ro_after_init = { .ops = genl_ops, .n_ops = ARRAY_SIZE(genl_ops), + .resv_start_op = WG_CMD_SET_DEVICE + 1, .name = WG_GENL_NAME, .version = WG_GENL_VERSION, .maxattr = WGDEVICE_A_MAX, diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c index c0cfd9b36c0b..720952b92e78 100644 --- a/drivers/net/wireguard/noise.c +++ b/drivers/net/wireguard/noise.c @@ -302,6 +302,41 @@ void wg_noise_set_static_identity_private_key( static_identity->static_public, private_key); } +static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen) +{ + struct blake2s_state state; + u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 }; + u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32)); + int i; + + if (keylen > BLAKE2S_BLOCK_SIZE) { + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, key, keylen); + blake2s_final(&state, x_key); + } else + memcpy(x_key, key, keylen); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, in, inlen); + blake2s_final(&state, i_hash); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x5c ^ 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE); + blake2s_final(&state, i_hash); + + memcpy(out, i_hash, BLAKE2S_HASH_SIZE); + memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE); + memzero_explicit(i_hash, BLAKE2S_HASH_SIZE); +} + /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 @@ -322,14 +357,14 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, ((third_len || third_dst) && (!second_len || !second_dst)))); /* Extract entropy from data into secret */ - blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN); + hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN); if (!first_dst || !first_len) goto out; /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; - blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE); + hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE); memcpy(first_dst, output, first_len); if (!second_dst || !second_len) @@ -337,8 +372,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_HASH_SIZE] = 2; - blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, - BLAKE2S_HASH_SIZE); + hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); memcpy(second_dst, output, second_len); if (!third_dst || !third_len) @@ -346,8 +380,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, /* Expand third key: key = secret, data = second-key || 0x3 */ output[BLAKE2S_HASH_SIZE] = 3; - blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, - BLAKE2S_HASH_SIZE); + hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); memcpy(third_dst, output, third_len); out: diff --git a/drivers/net/wireguard/peer.c b/drivers/net/wireguard/peer.c index 1acd00ab2fbc..1cb502a932e0 100644 --- a/drivers/net/wireguard/peer.c +++ b/drivers/net/wireguard/peer.c @@ -54,8 +54,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, skb_queue_head_init(&peer->staged_packet_queue); wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake); set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state); - netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll, - NAPI_POLL_WEIGHT); + netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll); napi_enable(&peer->napi); list_add_tail(&peer->peer_list, &wg->peer_list); INIT_LIST_HEAD(&peer->allowedips_list); diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c index 1de413b19e34..8084e7408c0a 100644 --- a/drivers/net/wireguard/queueing.c +++ b/drivers/net/wireguard/queueing.c @@ -4,6 +4,7 @@ */ #include "queueing.h" +#include <linux/skb_array.h> struct multicore_worker __percpu * wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr) @@ -42,7 +43,7 @@ void wg_packet_queue_free(struct crypt_queue *queue, bool purge) { free_percpu(queue->worker); WARN_ON(!purge && !__ptr_ring_empty(&queue->ring)); - ptr_ring_cleanup(&queue->ring, purge ? (void(*)(void*))kfree_skb : NULL); + ptr_ring_cleanup(&queue->ring, purge ? __skb_array_destroy_skb : NULL); } #define NEXT(skb) ((skb)->prev) diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index e2388107f7fd..583adb37ee1e 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -79,9 +79,7 @@ static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating) u8 sw_hash = skb->sw_hash; u32 hash = skb->hash; skb_scrub_packet(skb, true); - memset(&skb->headers_start, 0, - offsetof(struct sk_buff, headers_end) - - offsetof(struct sk_buff, headers_start)); + memset(&skb->headers, 0, sizeof(skb->headers)); if (encapsulating) { skb->l4_hash = l4_hash; skb->sw_hash = sw_hash; diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index 7b8df406c773..7135d51d2d87 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -19,15 +19,8 @@ /* Must be called with bh disabled. */ static void update_rx_stats(struct wg_peer *peer, size_t len) { - struct pcpu_sw_netstats *tstats = - get_cpu_ptr(peer->device->dev->tstats); - - u64_stats_update_begin(&tstats->syncp); - ++tstats->rx_packets; - tstats->rx_bytes += len; + dev_sw_netstats_rx_add(peer->device->dev, len); peer->rx_bytes += len; - u64_stats_update_end(&tstats->syncp); - put_cpu_ptr(tstats); } #define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type) diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index e173204ae7d7..19eac00b2381 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -284,7 +284,7 @@ static __init bool randomized_test(void) mutex_lock(&mutex); for (i = 0; i < NUM_RAND_ROUTES; ++i) { - prandom_bytes(ip, 4); + get_random_bytes(ip, 4); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, @@ -299,7 +299,7 @@ static __init bool randomized_test(void) } for (j = 0; j < NUM_MUTATED_ROUTES; ++j) { memcpy(mutated, ip, 4); - prandom_bytes(mutate_mask, 4); + get_random_bytes(mutate_mask, 4); mutate_amount = prandom_u32_max(32); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; @@ -310,7 +310,7 @@ static __init bool randomized_test(void) for (k = 0; k < 4; ++k) mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & - prandom_u32_max(256)); + get_random_u8()); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; if (wg_allowedips_insert_v4(&t, @@ -328,7 +328,7 @@ static __init bool randomized_test(void) } for (i = 0; i < NUM_RAND_ROUTES; ++i) { - prandom_bytes(ip, 16); + get_random_bytes(ip, 16); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, @@ -343,7 +343,7 @@ static __init bool randomized_test(void) } for (j = 0; j < NUM_MUTATED_ROUTES; ++j) { memcpy(mutated, ip, 16); - prandom_bytes(mutate_mask, 16); + get_random_bytes(mutate_mask, 16); mutate_amount = prandom_u32_max(128); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; @@ -354,7 +354,7 @@ static __init bool randomized_test(void) for (k = 0; k < 4; ++k) mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & - prandom_u32_max(256)); + get_random_u8()); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; if (wg_allowedips_insert_v6(&t, @@ -381,13 +381,13 @@ static __init bool randomized_test(void) for (j = 0;; ++j) { for (i = 0; i < NUM_QUERIES; ++i) { - prandom_bytes(ip, 4); + get_random_bytes(ip, 4); if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); pr_err("allowedips random v4 self-test: FAIL\n"); goto free; } - prandom_bytes(ip, 16); + get_random_bytes(ip, 16); if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { pr_err("allowedips random v6 self-test: FAIL\n"); goto free; @@ -593,10 +593,10 @@ bool __init wg_allowedips_selftest(void) wg_allowedips_remove_by_peer(&t, a, &mutex); test_negative(4, a, 192, 168, 0, 1); - /* These will hit the WARN_ON(len >= 128) in free_node if something - * goes wrong. + /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_BITS) in free_node + * if something goes wrong. */ - for (i = 0; i < 128; ++i) { + for (i = 0; i < MAX_ALLOWEDIPS_BITS; ++i) { part = cpu_to_be64(~(1LLU << (i % 64))); memset(&ip, 0xff, 16); memcpy((u8 *)&ip + (i < 64) * 8, &part, 8); diff --git a/drivers/net/wireguard/selftest/ratelimiter.c b/drivers/net/wireguard/selftest/ratelimiter.c index 007cd4457c5f..d4bb40a695ab 100644 --- a/drivers/net/wireguard/selftest/ratelimiter.c +++ b/drivers/net/wireguard/selftest/ratelimiter.c @@ -167,7 +167,7 @@ bool __init wg_ratelimiter_selftest(void) ++test; #endif - for (trials = TRIALS_BEFORE_GIVING_UP;;) { + for (trials = TRIALS_BEFORE_GIVING_UP; IS_ENABLED(DEBUG_RATELIMITER_TIMINGS);) { int test_count = 0, ret; ret = timings_test(skb4, hdr4, skb6, hdr6, &test_count); @@ -176,7 +176,6 @@ bool __init wg_ratelimiter_selftest(void) test += test_count; goto err; } - msleep(500); continue; } else if (ret < 0) { test += test_count; @@ -195,7 +194,6 @@ bool __init wg_ratelimiter_selftest(void) test += test_count; goto err; } - msleep(50); continue; } test += test_count; diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index 6f07b949cb81..0414d7a6ce74 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -160,6 +160,7 @@ out: rcu_read_unlock_bh(); return ret; #else + kfree_skb(skb); return -EAFNOSUPPORT; #endif } @@ -241,7 +242,7 @@ int wg_socket_endpoint_from_skb(struct endpoint *endpoint, endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr; endpoint->src4.s_addr = ip_hdr(skb)->daddr; endpoint->src_if4 = skb->skb_iif; - } else if (skb->protocol == htons(ETH_P_IPV6)) { + } else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) { endpoint->addr6.sin6_family = AF_INET6; endpoint->addr6.sin6_port = udp_hdr(skb)->source; endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; @@ -284,7 +285,7 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer, peer->endpoint.addr4 = endpoint->addr4; peer->endpoint.src4 = endpoint->src4; peer->endpoint.src_if4 = endpoint->src_if4; - } else if (endpoint->addr.sa_family == AF_INET6) { + } else if (IS_ENABLED(CONFIG_IPV6) && endpoint->addr.sa_family == AF_INET6) { peer->endpoint.addr6 = endpoint->addr6; peer->endpoint.src6 = endpoint->src6; } else { |