diff options
Diffstat (limited to 'src/netlink.c')
-rw-r--r-- | src/netlink.c | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/src/netlink.c b/src/netlink.c index e0f3632..90ff936 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -27,6 +27,8 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { [WGDEVICE_A_PEERS] = { .type = NLA_NESTED }, [WGDEVICE_A_DEV_NETNS_PID] = { .type = NLA_U32 }, [WGDEVICE_A_DEV_NETNS_FD] = { .type = NLA_U32 }, + [WGDEVICE_A_TRANSIT_NETNS_PID] = { .type = NLA_U32 }, + [WGDEVICE_A_TRANSIT_NETNS_FD] = { .type = NLA_U32 }, }; static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { @@ -346,23 +348,50 @@ static int wg_get_device_done(struct netlink_callback *cb) return 0; } -static int set_port(struct wg_device *wg, u16 port) +static int set_socket(struct wg_device *wg, struct nlattr **attrs) { + struct nlattr *netns_pid_attr, *netns_fd_attr, *port_attr; + struct net *net = NULL; struct wg_peer *peer; - int ret; + int ret = 0; + u16 port; - ret = test_socket_net_capable(wg->transit_net); - if (ret) - return ret; - if (wg->incoming_port == port) + netns_pid_attr = attrs[WGDEVICE_A_TRANSIT_NETNS_PID]; + netns_fd_attr = attrs[WGDEVICE_A_TRANSIT_NETNS_FD]; + port_attr = attrs[WGDEVICE_A_LISTEN_PORT]; + + if (!netns_pid_attr && !netns_fd_attr && !port_attr) return 0; + + net = get_attr_net(netns_pid_attr, netns_fd_attr); + if (IS_ERR(net)) + return PTR_ERR(net); + if (port_attr) + port = nla_get_u16(port_attr); + else + port = wg->incoming_port; + + ret = test_socket_net_capable(net ? : wg->transit_net); + if (ret) + goto out; + + if (wg->incoming_port == port && (!net || wg->transit_net == net)) + goto out; + list_for_each_entry(peer, &wg->peer_list, peer_list) wg_socket_clear_peer_endpoint_src(peer); if (!netif_running(wg->dev)) { wg->incoming_port = port; - return 0; + if (net) + wg_device_set_nets(wg, wg->dev_net, net); + goto out; } - return wg_socket_init(wg, wg->transit_net, port); + ret = wg_socket_init(wg, net ? : wg->transit_net, port); + +out: + if (net) + put_net(net); + return ret; } static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) @@ -559,12 +588,9 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) wg_socket_clear_peer_endpoint_src(peer); } - if (info->attrs[WGDEVICE_A_LISTEN_PORT]) { - ret = set_port(wg, - nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); - if (ret) - goto out; - } + ret = set_socket(wg, info->attrs); + if (ret) + goto out; if (info->attrs[WGDEVICE_A_FLAGS] && nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & |