aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/socket.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/socket.c')
-rw-r--r--src/socket.c64
1 files changed, 26 insertions, 38 deletions
diff --git a/src/socket.c b/src/socket.c
index 1ce74cd..5bf5a92 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -20,7 +20,6 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
.saddr = endpoint->src4.s_addr,
.daddr = endpoint->addr4.sin_addr.s_addr,
.fl4_dport = endpoint->addr4.sin_port,
- .fl4_sport = htons(wg->incoming_port),
.flowi4_mark = wg->fwmark,
.flowi4_proto = IPPROTO_UDP
};
@@ -34,6 +33,7 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock4);
+ fl.fl4_sport = inet_sk(sock)->inet_sport;
if (unlikely(!sock)) {
ret = -ENONET;
@@ -89,7 +89,6 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
.saddr = endpoint->src6,
.daddr = endpoint->addr6.sin6_addr,
.fl6_dport = endpoint->addr6.sin6_port,
- .fl6_sport = htons(wg->incoming_port),
.flowi6_mark = wg->fwmark,
.flowi6_oif = endpoint->addr6.sin6_scope_id,
.flowi6_proto = IPPROTO_UDP
@@ -105,6 +104,7 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock6);
+ fl.fl6_sport = inet_sk(sock)->inet_sport;
if (unlikely(!sock)) {
ret = -ENONET;
@@ -309,87 +309,75 @@ static inline void set_sock_opts(struct socket *sock)
sk_set_memalloc(sock->sk);
}
-int socket_init(struct wireguard_device *wg)
+int socket_init(struct wireguard_device *wg, u16 port)
{
- int ret = 0;
+ int ret;
struct udp_tunnel_sock_cfg cfg = {
.sk_user_data = wg,
.encap_type = 1,
.encap_rcv = receive
};
- struct socket *new4 = NULL;
+ struct socket *new4 = NULL, *new6 = NULL;
struct udp_port_cfg port4 = {
.family = AF_INET,
.local_ip.s_addr = htonl(INADDR_ANY),
- .local_udp_port = htons(wg->incoming_port),
+ .local_udp_port = htons(port),
.use_udp_checksums = true
};
#if IS_ENABLED(CONFIG_IPV6)
int retries = 0;
- struct socket *new6 = NULL;
struct udp_port_cfg port6 = {
.family = AF_INET6,
.local_ip6 = IN6ADDR_ANY_INIT,
- .local_udp_port = htons(wg->incoming_port),
.use_udp6_tx_checksums = true,
.use_udp6_rx_checksums = true,
.ipv6_v6only = true
};
#endif
- mutex_lock(&wg->socket_update_lock);
#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif
- if (rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)) || rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock))) {
- ret = -EADDRINUSE;
- goto out;
- }
ret = udp_sock_create(wg->creating_net, &port4, &new4);
if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
- goto out;
+ return ret;
}
- wg->incoming_port = ntohs(inet_sk(new4->sk)->inet_sport);
set_sock_opts(new4);
setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
- rcu_assign_pointer(wg->sock4, new4->sk);
#if IS_ENABLED(CONFIG_IPV6)
- if (!ipv6_mod_enabled())
- goto out;
- port6.local_udp_port = htons(wg->incoming_port);
- ret = udp_sock_create(wg->creating_net, &port6, &new6);
- if (ret < 0) {
- udp_tunnel_sock_release(new4);
- rcu_assign_pointer(wg->sock4, NULL);
- if (ret == -EADDRINUSE && !port4.local_udp_port && retries++ < 100)
- goto retry;
- if (!port4.local_udp_port)
- wg->incoming_port = 0;
- pr_err("%s: Could not create IPv6 socket\n", wg->dev->name);
- goto out;
+ if (ipv6_mod_enabled()) {
+ port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
+ ret = udp_sock_create(wg->creating_net, &port6, &new6);
+ if (ret < 0) {
+ udp_tunnel_sock_release(new4);
+ if (ret == -EADDRINUSE && !port && retries++ < 100)
+ goto retry;
+ pr_err("%s: Could not create IPv6 socket\n", wg->dev->name);
+ return ret;
+ }
+ set_sock_opts(new6);
+ setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
}
- set_sock_opts(new6);
- setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
- rcu_assign_pointer(wg->sock6, new6->sk);
#endif
-out:
- mutex_unlock(&wg->socket_update_lock);
- return ret;
+ socket_reinit(wg, new4 ? new4->sk : NULL, new6 ? new6->sk : NULL);
+ return 0;
}
-void socket_uninit(struct wireguard_device *wg)
+void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6)
{
struct sock *old4, *old6;
mutex_lock(&wg->socket_update_lock);
old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock));
old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock));
- rcu_assign_pointer(wg->sock4, NULL);
- rcu_assign_pointer(wg->sock6, NULL);
+ rcu_assign_pointer(wg->sock4, new4);
+ rcu_assign_pointer(wg->sock6, new6);
+ if (new4)
+ wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
mutex_unlock(&wg->socket_update_lock);
synchronize_rcu_bh();
synchronize_net();