aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/netlink.c')
-rw-r--r--src/netlink.c54
1 files changed, 40 insertions, 14 deletions
diff --git a/src/netlink.c b/src/netlink.c
index e0f3632..90ff936 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -27,6 +27,8 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
[WGDEVICE_A_PEERS] = { .type = NLA_NESTED },
[WGDEVICE_A_DEV_NETNS_PID] = { .type = NLA_U32 },
[WGDEVICE_A_DEV_NETNS_FD] = { .type = NLA_U32 },
+ [WGDEVICE_A_TRANSIT_NETNS_PID] = { .type = NLA_U32 },
+ [WGDEVICE_A_TRANSIT_NETNS_FD] = { .type = NLA_U32 },
};
static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
@@ -346,23 +348,50 @@ static int wg_get_device_done(struct netlink_callback *cb)
return 0;
}
-static int set_port(struct wg_device *wg, u16 port)
+static int set_socket(struct wg_device *wg, struct nlattr **attrs)
{
+ struct nlattr *netns_pid_attr, *netns_fd_attr, *port_attr;
+ struct net *net = NULL;
struct wg_peer *peer;
- int ret;
+ int ret = 0;
+ u16 port;
- ret = test_socket_net_capable(wg->transit_net);
- if (ret)
- return ret;
- if (wg->incoming_port == port)
+ netns_pid_attr = attrs[WGDEVICE_A_TRANSIT_NETNS_PID];
+ netns_fd_attr = attrs[WGDEVICE_A_TRANSIT_NETNS_FD];
+ port_attr = attrs[WGDEVICE_A_LISTEN_PORT];
+
+ if (!netns_pid_attr && !netns_fd_attr && !port_attr)
return 0;
+
+ net = get_attr_net(netns_pid_attr, netns_fd_attr);
+ if (IS_ERR(net))
+ return PTR_ERR(net);
+ if (port_attr)
+ port = nla_get_u16(port_attr);
+ else
+ port = wg->incoming_port;
+
+ ret = test_socket_net_capable(net ? : wg->transit_net);
+ if (ret)
+ goto out;
+
+ if (wg->incoming_port == port && (!net || wg->transit_net == net))
+ goto out;
+
list_for_each_entry(peer, &wg->peer_list, peer_list)
wg_socket_clear_peer_endpoint_src(peer);
if (!netif_running(wg->dev)) {
wg->incoming_port = port;
- return 0;
+ if (net)
+ wg_device_set_nets(wg, wg->dev_net, net);
+ goto out;
}
- return wg_socket_init(wg, wg->transit_net, port);
+ ret = wg_socket_init(wg, net ? : wg->transit_net, port);
+
+out:
+ if (net)
+ put_net(net);
+ return ret;
}
static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
@@ -559,12 +588,9 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
wg_socket_clear_peer_endpoint_src(peer);
}
- if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
- ret = set_port(wg,
- nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
- if (ret)
- goto out;
- }
+ ret = set_socket(wg, info->attrs);
+ if (ret)
+ goto out;
if (info->attrs[WGDEVICE_A_FLAGS] &&
nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) &