diff options
-rw-r--r-- | Makefile | 6 | ||||
-rw-r--r-- | common.c | 465 | ||||
-rw-r--r-- | common.h | 38 | ||||
-rw-r--r-- | ipm.c | 216 | ||||
-rw-r--r-- | ipm.h | 21 | ||||
-rw-r--r-- | lease.c | 48 | ||||
-rw-r--r-- | radix-trie.c | 124 | ||||
-rw-r--r-- | radix-trie.h | 41 | ||||
-rw-r--r-- | random.c | 25 | ||||
-rw-r--r-- | random.h | 1 | ||||
-rw-r--r-- | wg-dynamic-client.c | 550 | ||||
-rw-r--r-- | wg-dynamic-server.c | 268 |
12 files changed, 899 insertions, 904 deletions
@@ -43,10 +43,10 @@ COMPILE.c = @echo " CC $@"; COMPILE.c += $(BUILT_IN_COMPILE.c) endif -all: wg-dynamic-server +all: wg-dynamic-server wg-dynamic-client -wg-dynamic-client: wg-dynamic-client.o netlink.o common.o -wg-dynamic-server: wg-dynamic-server.o netlink.o radix-trie.o common.o random.o lease.o +wg-dynamic-client: wg-dynamic-client.o netlink.o common.o ipm.o +wg-dynamic-server: wg-dynamic-server.o netlink.o radix-trie.o common.o random.o lease.o ipm.o ifneq ($(V),1) clean: @@ -20,6 +20,60 @@ #include "common.h" #include "dbg.h" +union kvalues { + uint32_t u32; + struct wg_combined_ip ip; + char errmsg[256]; +}; + +static void request_ip(enum wg_dynamic_key key, union kvalues kv, void **dest) +{ + struct wg_combined_ip *ip = &kv.ip; + struct wg_dynamic_request_ip *r = (struct wg_dynamic_request_ip *)*dest; + + switch (key) { + case WGKEY_REQUEST_IP: + BUG_ON(*dest); + *dest = calloc(1, sizeof(struct wg_dynamic_request_ip)); + if (!*dest) + fatal("calloc()"); + + break; + case WGKEY_IPV4: + memcpy(&r->ipv4, &ip->ip4, sizeof r->ipv4); + r->cidrv4 = ip->cidr; + r->has_ipv4 = true; + break; + case WGKEY_IPV6: + memcpy(&r->ipv6, &ip->ip6, sizeof r->ipv6); + r->cidrv6 = ip->cidr; + r->has_ipv6 = true; + break; + case WGKEY_LEASESTART: + r->start = kv.u32; + break; + case WGKEY_LEASETIME: + r->leasetime = kv.u32; + break; + case WGKEY_ERRNO: + r->wg_errno = kv.u32; + break; + case WGKEY_ERRMSG: + r->errmsg = strdup(kv.errmsg); + break; + default: + debug("Invalid key %d, aborting\n", key); + BUG(); + } +} + +static void (*const deserialize_fptr[])(enum wg_dynamic_key key, + union kvalues kv, void **dest) = { + NULL, + NULL, + request_ip, +}; + static bool parse_ip_cidr(struct wg_combined_ip *ip, char *value) { uintmax_t res; @@ -50,281 +104,254 @@ static bool parse_ip_cidr(struct wg_combined_ip *ip, char *value) return true; } -static struct wg_dynamic_attr *parse_value(enum wg_dynamic_key key, char *value) +static bool parse_value(enum wg_dynamic_key key, char *str, union kvalues *kv) { - struct wg_dynamic_attr *attr; - size_t len; char *endptr; uintmax_t uresult; - union { - uint32_t uint32; - char errmsg[72]; - struct wg_combined_ip ip; - } data = { 0 }; + struct wg_combined_ip *ip; switch (key) { case WGKEY_IPV4: - len = sizeof data.ip; - data.ip.family = AF_INET; - if (!parse_ip_cidr(&data.ip, value)) - return NULL; - - break; case WGKEY_IPV6: - len = sizeof data.ip; - data.ip.family = AF_INET6; - if (!parse_ip_cidr(&data.ip, value)) - return NULL; + ip = &kv->ip; + ip->family = (key == WGKEY_IPV4) ? AF_INET : AF_INET6; + if (!parse_ip_cidr(ip, str)) + return false; break; + case WGKEY_REQUEST_IP: case WGKEY_LEASESTART: case WGKEY_LEASETIME: case WGKEY_ERRNO: - len = sizeof data.uint32; - uresult = strtoumax(value, &endptr, 10); + uresult = strtoumax(str, &endptr, 10); if (uresult > UINT32_MAX || *endptr != '\0') - return NULL; - data.uint32 = (uint32_t)uresult; + return false; + kv->u32 = (uint32_t)uresult; break; case WGKEY_ERRMSG: - strncpy(data.errmsg, value, sizeof data.errmsg - 1); - data.errmsg[sizeof data.errmsg - 1] = '\0'; - len = MIN(sizeof data.errmsg, - strlen(value) + 1); /* Copying the NUL byte too. */ - + strncpy(kv->errmsg, str, sizeof kv->errmsg); + kv->errmsg[sizeof kv->errmsg - 1] = '\0'; break; default: debug("Invalid key %d, aborting\n", key); BUG(); } - attr = malloc(sizeof(struct wg_dynamic_attr) + len); - if (!attr) - fatal("malloc()"); - - attr->len = len; - attr->key = key; - attr->next = NULL; - memcpy(&attr->value, &data, len); - - return attr; + return true; } static enum wg_dynamic_key parse_key(char *key) { - for (enum wg_dynamic_key e = 1; e < ARRAY_SIZE(WG_DYNAMIC_KEY); ++e) + for (enum wg_dynamic_key e = 2; e < ARRAY_SIZE(WG_DYNAMIC_KEY); ++e) if (!strcmp(key, WG_DYNAMIC_KEY[e])) return e; return WGKEY_UNKNOWN; } -/* Consumes one full line from buf, or up to MAX_LINESIZE-1 bytes if no newline - * character was found. - * If req != NULL then we expect to parse a command and will set cmd and version - * of req accordingly, while *attr will be set to NULL. - * Otherwise we expect to parse a normal key=value pair, that will be stored - * in a newly allocated wg_dynamic_attr, pointed to by *attr. +/* Consumes one full line from buf, or up to MAX_LINESIZE bytes if no newline + * character was found. If less then MAX_LINESIZE bytes are available, a new + * buffer will be allocated and req->buf and req->len set accordingly. * * Return values: * > 0 : Amount of bytes consumed (<= MAX_LINESIZE) + * = 0 : Consumed len bytes; need more for a full line * < 0 : Error - * = 0 : End of message */ static ssize_t parse_line(unsigned char *buf, size_t len, - struct wg_dynamic_attr **attr, - struct wg_dynamic_request *req) + struct wg_dynamic_request *req, + enum wg_dynamic_key *key, union kvalues *kv) { unsigned char *line_end, *key_end; - enum wg_dynamic_key key; ssize_t line_len; - char *endptr; - uintmax_t res; - line_end = memchr(buf, '\n', len > MAX_LINESIZE ? MAX_LINESIZE : len); + line_end = memchr(buf, '\n', MIN(len, MAX_LINESIZE)); if (!line_end) { if (len >= MAX_LINESIZE) return -E2BIG; - *attr = malloc(sizeof(struct wg_dynamic_attr) + len); - if (!*attr) + req->len = len; + req->buf = malloc(len); + if (!req->buf) fatal("malloc()"); - (*attr)->key = WGKEY_INCOMPLETE; - (*attr)->len = len; - (*attr)->next = NULL; - memcpy((*attr)->value, buf, len); - - return len; + memcpy(req->buf, buf, len); + return 0; } - if (line_end == buf) - return 0; /* \n\n - end of message */ + if (line_end == buf) { + *key = WGKEY_EOMSG; + return 1; + } *line_end = '\0'; line_len = line_end - buf + 1; key_end = memchr(buf, '=', line_len - 1); - if (!key_end) + if (!key_end || key_end == buf) return -EINVAL; *key_end = '\0'; - key = parse_key((char *)buf); - if (key == WGKEY_UNKNOWN) - return -ENOENT; - - if (req) { - if (key >= WGKEY_ENDCMD) - return -ENOENT; - - *attr = NULL; - res = strtoumax((char *)key_end + 1, &endptr, 10); - - if (res > UINT32_MAX || *endptr != '\0') - return -EINVAL; - - req->cmd = key; - req->version = (uint32_t)res; + *key = parse_key((char *)buf); + if (*key == WGKEY_UNKNOWN) + return line_len; - if (req->version != 1) - return -EPROTONOSUPPORT; - } else { - if (key <= WGKEY_ENDCMD) - return -ENOENT; - - *attr = parse_value(key, (char *)key_end + 1); - if (!*attr) - return -EINVAL; - } + if (!parse_value(*key, (char *)key_end + 1, kv)) + return -EINVAL; return line_len; } -static int parse_request(struct wg_dynamic_request *req, unsigned char *buf, - size_t len) +static ssize_t parse_request(struct wg_dynamic_request *req, unsigned char *buf, + size_t len) { - struct wg_dynamic_attr *attr; - size_t offset = 0; - ssize_t ret; + ssize_t ret, offset = 0; + enum wg_dynamic_key key; + union kvalues kv; + void (*deserialize)(enum wg_dynamic_key key, union kvalues kv, + void **dest); if (memchr(buf, '\0', len)) return -EINVAL; /* don't allow null bytes */ - if (req->last && req->last->key == WGKEY_INCOMPLETE) { - len += req->last->len; + if (req->cmd == WGKEY_UNKNOWN) { + ret = parse_line(buf, len, req, &req->cmd, &kv); + if (ret <= 0) + return ret; - memmove(buf + req->last->len, buf, len); - memcpy(buf, req->last->value, req->last->len); - free(req->last); + req->version = kv.u32; + if (req->cmd >= WGKEY_ENDCMD || req->cmd <= WGKEY_EOMSG || + req->version != 1) + return -EPROTONOSUPPORT; - if (req->first == req->last) { - req->first = NULL; - req->last = NULL; - } else { - attr = req->first; - while (attr->next != req->last) - attr = attr->next; + len -= ret; + offset += ret; - attr->next = NULL; - req->last = attr; - } + deserialize = deserialize_fptr[req->cmd]; + deserialize(req->cmd, kv, &req->result); + } else { + deserialize = deserialize_fptr[req->cmd]; } while (len > 0) { - ret = parse_line(buf + offset, len, &attr, - req->cmd == WGKEY_UNKNOWN ? req : NULL); + ret = parse_line(buf + offset, len, req, &key, &kv); if (ret <= 0) - return ret; /* either error or message complete */ + return ret; len -= ret; offset += ret; - if (!attr) - continue; - if (!req->first) - req->first = attr; - else - req->last->next = attr; + if (key == WGKEY_EOMSG) + return offset; + else if (key == WGKEY_UNKNOWN) + continue; + else if (key <= WGKEY_ENDCMD) + return -EINVAL; - req->last = attr; + deserialize(key, kv, &req->result); } - return 1; + return 0; } -bool handle_request(struct wg_dynamic_request *req, - bool (*success)(struct wg_dynamic_request *), - bool (*error)(struct wg_dynamic_request *, int)) +ssize_t handle_request(int fd, struct wg_dynamic_request *req, + unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE], + size_t *remaining) { - ssize_t bytes; - int ret; - unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE]; + ssize_t bytes, processed; + size_t leftover; - while (1) { - bytes = read(req->fd, buf, RECV_BUFSIZE); - if (bytes < 0) { - if (errno == EWOULDBLOCK || errno == EAGAIN) - break; + BUG_ON((*remaining > 0 && req->buf) || (req->buf && !req->len)); + + do { + leftover = req->len; + if (*remaining > 0) + bytes = *remaining; + else + bytes = read(fd, buf + leftover, RECV_BUFSIZE); - // TODO: handle EINTR + if (bytes < 0) { + if (errno == EWOULDBLOCK || errno == EAGAIN || + errno == EINTR) + return 0; - debug("Reading from socket %d failed: %s\n", req->fd, + debug("Reading from socket %d failed: %s\n", fd, strerror(errno)); - return true; + return -1; } else if (bytes == 0) { - debug("Peer disconnected unexpectedly\n"); - return true; + return -1; } - ret = parse_request(req, buf, bytes); - if (ret < 0) - return error(req, -ret); - else if (ret == 0) - return success(req); - } + if (req->buf) { + memcpy(buf, req->buf, leftover); + free(req->buf); + req->buf = NULL; + req->len = 0; + } - return false; + processed = parse_request(req, buf, bytes + leftover); + if (processed < 0) + return processed; /* Parsing error */ + if (!processed) + *remaining = 0; + } while (processed == 0); + + *remaining = (bytes + leftover) - processed; + memmove(buf, buf + processed, *remaining); + + return 1; } -bool send_message(struct wg_dynamic_request *req, const void *buf, size_t len) +void free_wg_dynamic_request(struct wg_dynamic_request *req) { - size_t offset = 0; + BUG_ON(req->buf || req->len); - while (1) { - ssize_t written = write(req->fd, buf + offset, len - offset); - if (written < 0) { - if (errno == EWOULDBLOCK || errno == EAGAIN) - break; + req->cmd = WGKEY_UNKNOWN; + req->version = 0; + if (req->result) { + free(((struct wg_dynamic_request_ip *)req->result)->errmsg); + free(req->result); + req->result = NULL; + } +} - // TODO: handle EINTR +size_t serialize_request_ip(bool send, char *buf, size_t len, + struct wg_dynamic_request_ip *rip) +{ + size_t off = 0; + char addrbuf[INET6_ADDRSTRLEN]; - debug("Writing to socket %d failed: %s\n", req->fd, - strerror(errno)); - return true; - } + if (send) + print_to_buf(buf, len, &off, "request_ip=1\n"); - offset += written; - if (offset == len) - return true; - } + if (rip->has_ipv4) { + if (!inet_ntop(AF_INET, &rip->ipv4, addrbuf, sizeof addrbuf)) + fatal("inet_ntop()"); - debug("Socket %d blocking on write with %lu bytes left, postponing\n", - req->fd, len - offset); + print_to_buf(buf, len, &off, "ipv4=%s/32\n", addrbuf); + } - if (!req->buf) { - req->buflen = len - offset; - req->buf = malloc(req->buflen); - if (!req->buf) - fatal("malloc()"); + if (rip->has_ipv6) { + if (!inet_ntop(AF_INET6, &rip->ipv6, addrbuf, sizeof addrbuf)) + fatal("inet_ntop()"); - memcpy(req->buf, buf + offset, req->buflen); - } else { - req->buflen = len - offset; - memmove(req->buf, buf + offset, req->buflen); + print_to_buf(buf, len, &off, "ipv6=%s/128\n", addrbuf); } - return false; + if (rip->start && rip->leasetime) + print_to_buf(buf, len, &off, "leasestart=%u\nleasetime=%u\n", + rip->start, rip->leasetime); + + if (rip->errmsg) + print_to_buf(buf, len, &off, "errmsg=%s\n", rip->errmsg); + + if (!send) + print_to_buf(buf, len, &off, "errno=%u\n", rip->wg_errno); + + print_to_buf(buf, len, &off, "\n"); + + return off; } void print_to_buf(char *buf, size_t bufsize, size_t *offset, char *fmt, ...) @@ -345,106 +372,8 @@ void print_to_buf(char *buf, size_t bufsize, size_t *offset, char *fmt, ...) *offset += n; } -uint32_t current_time() -{ - struct timespec tp; - if (clock_gettime(CLOCK_REALTIME, &tp)) - fatal("clock_gettime(CLOCK_REALTIME)"); - return tp.tv_sec; -} - -void close_connection(struct wg_dynamic_request *req) -{ - struct wg_dynamic_attr *prev, *cur = req->first; - - if (close(req->fd)) - debug("Failed to close socket\n"); - - while (cur) { - prev = cur; - cur = cur->next; - free(prev); - } - - req->cmd = WGKEY_UNKNOWN; - req->version = 0; - req->fd = -1; - free(req->buf); - req->buf = NULL; - req->buflen = 0; - req->first = NULL; - req->last = NULL; -} - bool is_link_local(unsigned char *addr) { /* TODO: check if the remaining 54 bits are 0 */ return IN6_IS_ADDR_LINKLOCAL(addr); } - -void iface_get_all_addrs(uint8_t family, mnl_cb_t data_cb, void *cb_data) -{ - struct mnl_socket *nl; - char buf[MNL_SOCKET_BUFFER_SIZE]; - struct nlmsghdr *nlh; - /* TODO: rtln-addr-dump from libmnl uses rtgenmsg here? */ - struct ifaddrmsg *ifaddr; - int ret; - unsigned int seq, portid; - - nl = mnl_socket_open(NETLINK_ROUTE); - if (nl == NULL) - fatal("mnl_socket_open"); - - if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) - fatal("mnl_socket_bind"); - - /* You'd think that we could just request addresses from a specific - * interface, via NLM_F_MATCH or something, but we can't. See also: - * https://marc.info/?l=linux-netdev&m=132508164508217 - */ - seq = time(NULL); - portid = mnl_socket_get_portid(nl); - nlh = mnl_nlmsg_put_header(buf); - nlh->nlmsg_type = RTM_GETADDR; - nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - nlh->nlmsg_seq = seq; - ifaddr = mnl_nlmsg_put_extra_header(nlh, sizeof(struct ifaddrmsg)); - ifaddr->ifa_family = family; - - if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) - fatal("mnl_socket_sendto"); - - do { - ret = mnl_socket_recvfrom(nl, buf, sizeof(buf)); - if (ret <= MNL_CB_STOP) - break; - ret = mnl_cb_run(buf, ret, seq, portid, data_cb, cb_data); - } while (ret > 0); - - if (ret == -1) - fatal("mnl_cb_run/mnl_socket_recvfrom"); - - mnl_socket_close(nl); -} - -int data_attr_cb(const struct nlattr *attr, void *data) -{ - const struct nlattr **tb = data; - int type = mnl_attr_get_type(attr); - - /* skip unsupported attribute in user-space */ - if (mnl_attr_type_valid(attr, IFA_MAX) < 0) - return MNL_CB_OK; - - switch (type) { - case IFA_ADDRESS: - if (mnl_attr_validate(attr, MNL_TYPE_BINARY) < 0) { - perror("mnl_attr_validate"); - return MNL_CB_ERROR; - } - break; - } - tb[type] = attr; - return MNL_CB_OK; -} @@ -15,20 +15,17 @@ #include "netlink.h" #define MAX_CONNECTIONS 16 - #define MAX_LINESIZE 4096 - #define RECV_BUFSIZE 8192 - #define MAX_RESPONSE_SIZE 8192 static const char WG_DYNAMIC_ADDR[] = "fe80::"; static const uint16_t WG_DYNAMIC_PORT = 970; /* ASCII sum of "wireguard" */ - #define WG_DYNAMIC_DEFAULT_LEASETIME 3600 #define ITEMS \ E(WGKEY_UNKNOWN, "") /* must be the first entry */ \ + E(WGKEY_EOMSG, "") \ /* CMD START */ \ E(WGKEY_REQUEST_IP, "request_ip") \ E(WGKEY_ENDCMD, "") \ @@ -62,13 +59,6 @@ static const char *const WG_DYNAMIC_ERR[] = { ITEMS }; #undef E #undef ITEMS -struct wg_dynamic_attr { - enum wg_dynamic_key key; - size_t len; - struct wg_dynamic_attr *next; - unsigned char value[]; -}; - struct wg_dynamic_request { enum wg_dynamic_key cmd; uint32_t version; @@ -76,8 +66,17 @@ struct wg_dynamic_request { wg_key pubkey; struct in6_addr lladdr; unsigned char *buf; - size_t buflen; - struct wg_dynamic_attr *first, *last; + size_t len; /* <= MAX_LINESIZE */ + void *result; +}; + +struct wg_dynamic_request_ip { + struct in_addr ipv4; + struct in6_addr ipv6; + uint8_t cidrv4, cidrv6; + uint32_t leasetime, start, wg_errno; + bool has_ipv4, has_ipv6; + char *errmsg; }; struct wg_combined_ip { @@ -91,15 +90,12 @@ struct wg_combined_ip { #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0])) +ssize_t handle_request(int fd, struct wg_dynamic_request *req, + unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE], + size_t *remaining); void free_wg_dynamic_request(struct wg_dynamic_request *req); -bool handle_request(struct wg_dynamic_request *req, - bool (*success)(struct wg_dynamic_request *), - bool (*error)(struct wg_dynamic_request *, int)); -bool send_message(struct wg_dynamic_request *req, const void *buf, size_t len); +size_t serialize_request_ip(bool include_header, char *buf, size_t len, + struct wg_dynamic_request_ip *rip); void print_to_buf(char *buf, size_t bufsize, size_t *offset, char *fmt, ...); -uint32_t current_time(); -void close_connection(struct wg_dynamic_request *req); bool is_link_local(unsigned char *addr); -void iface_get_all_addrs(uint8_t family, mnl_cb_t data_cb, void *cb_data); -int data_attr_cb(const struct nlattr *attr, void *data); #endif @@ -0,0 +1,216 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#include <arpa/inet.h> +#include <libmnl/libmnl.h> +#include <linux/rtnetlink.h> +#include <stdint.h> +#include <time.h> + +#include "dbg.h" +#include "common.h" + +struct mnl_cb_data { + uint32_t ifindex; + struct wg_combined_ip *ip; + bool ip_found; + bool duplicate; +}; + +static struct mnl_socket *nl = NULL; + +static int iface_update(uint16_t cmd, uint16_t flags, uint32_t ifindex, + const uint8_t *addr, uint8_t cidr, sa_family_t family) +{ + char buf[MNL_SOCKET_BUFFER_SIZE]; + struct nlmsghdr *nlh; + unsigned int seq, portid; + struct ifaddrmsg *ifaddr; + int ret; + + portid = mnl_socket_get_portid(nl); + nlh = mnl_nlmsg_put_header(buf); + nlh->nlmsg_type = cmd; + nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | flags; + nlh->nlmsg_seq = seq = time(NULL); + ifaddr = mnl_nlmsg_put_extra_header(nlh, sizeof(struct ifaddrmsg)); + ifaddr->ifa_family = family; + ifaddr->ifa_prefixlen = cidr; + ifaddr->ifa_scope = RT_SCOPE_UNIVERSE; + ifaddr->ifa_index = ifindex; + mnl_attr_put(nlh, IFA_LOCAL, family == AF_INET ? 4 : 16, addr); + + if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) + return -1; + + do { + ret = mnl_socket_recvfrom(nl, buf, sizeof(buf)); + if (ret <= MNL_CB_STOP) + break; + ret = mnl_cb_run(buf, ret, seq, portid, NULL, NULL); + } while (ret > 0); + + if (ret == -1) + return -1; + + return 0; +} + +static int data_attr_cb(const struct nlattr *attr, void *data) +{ + const struct nlattr **tb = data; + int type = mnl_attr_get_type(attr); + + /* skip unsupported attribute in user-space */ + if (mnl_attr_type_valid(attr, IFA_MAX) < 0) + return MNL_CB_OK; + + switch (type) { + case IFA_ADDRESS: + if (mnl_attr_validate(attr, MNL_TYPE_BINARY) < 0) { + perror("mnl_attr_validate"); + return MNL_CB_ERROR; + } + break; + } + tb[type] = attr; + return MNL_CB_OK; +} + +static int data_cb(const struct nlmsghdr *nlh, void *data) +{ + struct nlattr *tb[IFA_MAX + 1] = {}; + struct ifaddrmsg *ifa = mnl_nlmsg_get_payload(nlh); + struct mnl_cb_data *cb_data = (struct mnl_cb_data *)data; + + if (ifa->ifa_index != cb_data->ifindex) + return MNL_CB_OK; + + if (ifa->ifa_scope != RT_SCOPE_LINK) + return MNL_CB_OK; + + mnl_attr_parse(nlh, sizeof(*ifa), data_attr_cb, tb); + + if (!tb[IFA_ADDRESS]) + return MNL_CB_OK; + + if (cb_data->ip_found) { + cb_data->duplicate = true; + return MNL_CB_OK; + } + + memcpy(cb_data->ip, mnl_attr_get_payload(tb[IFA_ADDRESS]), + ifa->ifa_family == AF_INET ? 4 : 16); + cb_data->ip->cidr = ifa->ifa_prefixlen; + cb_data->ip->family = ifa->ifa_family; + + char out[INET6_ADDRSTRLEN]; + inet_ntop(ifa->ifa_family, cb_data->ip, out, sizeof(out)); + debug("index=%d, family=%d, addr=%s\n", ifa->ifa_index, ifa->ifa_family, + out); + + cb_data->ip_found = true; + + return MNL_CB_OK; +} + +static int iface_get_all_addrs(uint8_t family, void *cb_data) +{ + char buf[MNL_SOCKET_BUFFER_SIZE]; + struct nlmsghdr *nlh; + /* TODO: rtln-addr-dump from libmnl uses rtgenmsg here? */ + struct ifaddrmsg *ifaddr; + int ret; + unsigned int seq, portid; + + /* You'd think that we could just request addresses from a specific + * interface, via NLM_F_MATCH or something, but we can't. See also: + * https://marc.info/?l=linux-netdev&m=132508164508217 + */ + seq = time(NULL); + portid = mnl_socket_get_portid(nl); + nlh = mnl_nlmsg_put_header(buf); + nlh->nlmsg_type = RTM_GETADDR; + nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + nlh->nlmsg_seq = seq; + ifaddr = mnl_nlmsg_put_extra_header(nlh, sizeof(struct ifaddrmsg)); + ifaddr->ifa_family = family; + + if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) + return -1; + + do { + ret = mnl_socket_recvfrom(nl, buf, sizeof(buf)); + if (ret <= MNL_CB_STOP) + break; + ret = mnl_cb_run(buf, ret, seq, portid, data_cb, cb_data); + } while (ret > 0); + + if (ret == -1) + return -1; + + return 0; +} + +void ipm_init() +{ + nl = mnl_socket_open(NETLINK_ROUTE); + if (nl == NULL) + fatal("mnl_socket_open()"); + + if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) + fatal("mnl_socket_bind()"); +} + +void ipm_free() +{ + if (nl) + mnl_socket_close(nl); +} + +int ipm_newaddr_v4(uint32_t ifindex, const struct in_addr *ip) +{ + return iface_update(RTM_NEWADDR, NLM_F_REPLACE | NLM_F_CREATE, ifindex, + (uint8_t *)ip, 32, AF_INET); +} + +int ipm_newaddr_v6(uint32_t ifindex, const struct in6_addr *ip) +{ + return iface_update(RTM_NEWADDR, NLM_F_REPLACE | NLM_F_CREATE, ifindex, + (uint8_t *)ip, 128, AF_INET6); +} + +int ipm_deladdr_v4(uint32_t ifindex, const struct in_addr *ip) +{ + return iface_update(RTM_DELADDR, 0, ifindex, (uint8_t *)ip, 32, + AF_INET); +} + +int ipm_deladdr_v6(uint32_t ifindex, const struct in6_addr *ip) +{ + return iface_update(RTM_DELADDR, 0, ifindex, (uint8_t *)ip, 128, + AF_INET6); +} + +int ipm_getlladdr(uint32_t ifindex, struct wg_combined_ip *addr) +{ + struct mnl_cb_data cb_data = { + .ifindex = ifindex, + .ip = addr, + .ip_found = false, + .duplicate = false, + }; + + if (iface_get_all_addrs(AF_INET6, &cb_data)) + return -1; + + if (!cb_data.ip_found) + return -2; + + if (cb_data.duplicate) + return -3; + + return 0; +} @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#ifndef __IPM_H__ +#define __IPM_H__ + +#include <stdint.h> + +#include "common.h" + +void ipm_init(); +void ipm_free(); +int ipm_newaddr_v4(uint32_t ifindex, const struct in_addr *ip); +int ipm_newaddr_v6(uint32_t ifindex, const struct in6_addr *ip); +int ipm_deladdr_v4(uint32_t ifindex, const struct in_addr *ip); +int ipm_deladdr_v6(uint32_t ifindex, const struct in6_addr *ip); +int ipm_getlladdr(uint32_t ifindex, struct wg_combined_ip *addr); + +#endif @@ -26,7 +26,7 @@ #define TIME_T_MAX (((time_t)1 << (sizeof(time_t) * CHAR_BIT - 2)) - 1) * 2 + 1 -static struct ip_pool pool; +static struct ipns ipns; static time_t gexpires = TIME_T_MAX; static bool synchronized; @@ -61,7 +61,7 @@ void leases_init(char *fname, struct mnl_socket *nlsock) synchronized = false; leases_ht = kh_init(leaseht); - ipp_init(&pool); + ipp_init(&ipns); nlh = mnl_nlmsg_put_header(buf); nlh->nlmsg_type = RTM_GETROUTE; @@ -88,7 +88,7 @@ void leases_free() } kh_destroy(leaseht, leases_ht); - ipp_free(&pool); + ipp_free(&ipns); } static struct wg_dynamic_lease *new_lease(const struct wg_dynamic_lease *lease, @@ -152,7 +152,7 @@ struct wg_dynamic_lease *set_lease(const char *devname, wg_key pubkey, sizeof(ip_asc)); debug("deleting from pool: %s\n", ip_asc); - if (ipp_del_v4(&pool, ¤t->ipv4, 32)) + if (ipp_del_v4(&ipns, ¤t->ipv4, 32)) die("ipp_del_v4()\n"); } @@ -162,26 +162,26 @@ struct wg_dynamic_lease *set_lease(const char *devname, wg_key pubkey, sizeof(ip_asc)); debug("deleting from pool: %s\n", ip_asc); - if (ipp_del_v6(&pool, ¤t->ipv6, 128)) + if (ipp_del_v6(&ipns, ¤t->ipv6, 128)) die("ipp_del_v6()\n"); } } if (wants_ipv4 && !new->ipv4.s_addr) { - if (!pool.total_ipv4) { + if (!ipns.total_ipv4) { debug("IPv4 pool empty\n"); } else if (!ipv4) { - index = random_bounded(pool.total_ipv4 - 1); - debug("new_lease(v4): %u of %u\n", index, - pool.total_ipv4); + index = random_bounded(ipns.total_ipv4 - 1); + debug("new_lease(v4): %u of %ju\n", index, + ipns.total_ipv4); - ipp_addnth_v4(&pool, &new->ipv4, index); + ipp_addnth_v4(&ipns, &new->ipv4, index); } else { char ip_asc[INET_ADDRSTRLEN]; inet_ntop(AF_INET, ipv4, ip_asc, sizeof(ip_asc)); debug("wants %s: ", ip_asc); - if (!ipp_add_v4(&pool, ipv4, 32)) { + if (!ipp_add_v4(&ipns, ipv4, 32)) { debug("allocated\n"); new->ipv4 = *ipv4; @@ -192,26 +192,26 @@ struct wg_dynamic_lease *set_lease(const char *devname, wg_key pubkey, } if (wants_ipv6 && IN6_IS_ADDR_UNSPECIFIED(&new->ipv6)) { - if (!pool.totalh_ipv6 && !pool.totall_ipv6) { + if (!ipns.totalh_ipv6 && !ipns.totall_ipv6) { debug("IPv6 pool empty\n"); } else if (!ipv6) { - if (pool.totalh_ipv6 > 0) { + if (ipns.totalh_ipv6 > 0) { index_l = random_bounded(UINT64_MAX); - index_h = random_bounded(pool.totalh_ipv6 - 1); + index_h = random_bounded(ipns.totalh_ipv6 - 1); } else { - index_l = random_bounded(pool.totall_ipv6 - 1); + index_l = random_bounded(ipns.totall_ipv6); index_h = 0; } debug("new_lease(v6): %u:%ju of %u:%ju\n", index_h, - index_l, pool.totalh_ipv6, pool.totall_ipv6); - ipp_addnth_v6(&pool, &new->ipv6, index_l, index_h); + index_l, ipns.totalh_ipv6, ipns.totall_ipv6); + ipp_addnth_v6(&ipns, &new->ipv6, index_l, index_h); } else { char ip_asc[INET6_ADDRSTRLEN]; inet_ntop(AF_INET6, ipv6, ip_asc, sizeof(ip_asc)); debug("wants %s: ", ip_asc); - if (!ipp_add_v6(&pool, ipv6, 128)) { + if (!ipp_add_v6(&ipns, ipv6, 128)) { debug("allocated\n"); new->ipv6 = *ipv6; @@ -398,10 +398,10 @@ int leases_refresh(const char *devname) time_t expires = lease->start_mono + lease->leasetime; if (cur_time >= expires) { if (lease->ipv4.s_addr) - ipp_del_v4(&pool, &lease->ipv4, 32); + ipp_del_v4(&ipns, &lease->ipv4, 32); if (!IN6_IS_ADDR_UNSPECIFIED(&lease->ipv6)) - ipp_del_v6(&pool, &lease->ipv6, 128); + ipp_del_v6(&ipns, &lease->ipv6, 128); memcpy(updates[i].peer_pubkey, kh_key(leases_ht, k), sizeof(wg_key)); @@ -500,18 +500,18 @@ static int process_nlpacket_cb(const struct nlmsghdr *nlh, void *data) if (nlh->nlmsg_type == RTM_NEWROUTE) { if (rm->rtm_family == AF_INET) { - if (ipp_addpool_v4(&pool, addr, rm->rtm_dst_len)) + if (ipp_addpool_v4(&ipns, addr, rm->rtm_dst_len)) die("ipp_addpool_v4()\n"); } else if (rm->rtm_family == AF_INET6) { - if (ipp_addpool_v6(&pool, addr, rm->rtm_dst_len)) + if (ipp_addpool_v6(&ipns, addr, rm->rtm_dst_len)) die("ipp_addpool_v6()\n"); } } else if (nlh->nlmsg_type == RTM_DELROUTE) { if (rm->rtm_family == AF_INET) { - if (ipp_removepool_v4(&pool, addr) && synchronized) + if (ipp_removepool_v4(&ipns, addr) && synchronized) die("ipp_removepool_v4()\n"); } else if (rm->rtm_family == AF_INET6) { - if (ipp_removepool_v6(&pool, addr) && synchronized) + if (ipp_removepool_v6(&ipns, addr) && synchronized) die("ipp_removepool_v6()\n"); } } diff --git a/radix-trie.c b/radix-trie.c index 25bdd75..3557063 100644 --- a/radix-trie.c +++ b/radix-trie.c @@ -377,35 +377,35 @@ static int remove_node(struct radix_node *trie, const uint8_t *key, return 0; } -static void totalip_inc(struct ip_pool *ipp, uint8_t bits, uint8_t val) +static void totalip_inc(struct ipns *ns, uint8_t bits, uint8_t val) { if (bits == 32) { - BUG_ON(val >= 32); - ipp->total_ipv4 += 1ULL << val; + BUG_ON(val > 32); + ns->total_ipv4 += 1ULL << val; } else if (bits == 128) { - uint64_t tmp = ipp->totall_ipv6; + uint64_t tmp = ns->totall_ipv6; BUG_ON(val > 64); - ipp->totall_ipv6 += (val == 64) ? 0 : 1ULL << val; - if (ipp->totall_ipv6 <= tmp) - ++ipp->totalh_ipv6; + ns->totall_ipv6 += (val == 64) ? 0 : 1ULL << val; + if (ns->totall_ipv6 <= tmp) + ++ns->totalh_ipv6; } } -static void totalip_dec(struct ip_pool *ipp, uint8_t bits, uint8_t val) +static void totalip_dec(struct ipns *ns, uint8_t bits, uint8_t val) { if (bits == 32) { - BUG_ON(val >= 32); - ipp->total_ipv4 -= 1ULL << val; + BUG_ON(val > 32); + ns->total_ipv4 -= 1ULL << val; } else if (bits == 128) { - uint64_t tmp = ipp->totall_ipv6; + uint64_t tmp = ns->totall_ipv6; BUG_ON(val > 64); - ipp->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val; - if (ipp->totall_ipv6 >= tmp) - --ipp->totalh_ipv6; + ns->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val; + if (ns->totall_ipv6 >= tmp) + --ns->totalh_ipv6; } } -static int ipp_addpool(struct ip_pool *ipp, struct radix_pool **pool, +static int ipp_addpool(struct ipns *ns, struct radix_pool **pool, struct radix_node **root, uint8_t bits, const uint8_t *key, uint8_t cidr) { @@ -421,7 +421,7 @@ static int ipp_addpool(struct ip_pool *ipp, struct radix_pool **pool, shadowed = true; } else if (cidr < node->cidr && !(*pool)->shadowed) { (*pool)->shadowed = true; - totalip_dec(ipp, bits, bits - cidr); + totalip_dec(ns, bits, bits - cidr); } else { return -1; } @@ -439,7 +439,7 @@ static int ipp_addpool(struct ip_pool *ipp, struct radix_pool **pool, } if (!shadowed) - totalip_inc(ipp, bits, bits - cidr); + totalip_inc(ns, bits, bits - cidr); newpool = malloc(sizeof *newpool); if (!newpool) @@ -505,94 +505,94 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits) debug_print_trie(root->bit[1], bits); } -void debug_print_trie_v4(struct ip_pool *pool) +void debug_print_trie_v4(struct ipns *ns) { - debug_print_trie(pool->ip4_root, 32); + debug_print_trie(ns->ip4_root, 32); } -void debug_print_trie_v6(struct ip_pool *pool) +void debug_print_trie_v6(struct ipns *ns) { - debug_print_trie(pool->ip6_root, 128); + debug_print_trie(ns->ip6_root, 128); } #endif -void ipp_init(struct ip_pool *pool) +void ipp_init(struct ipns *ns) { - pool->ip4_root = pool->ip6_root = NULL; - pool->ip4_pool = pool->ip6_pool = NULL; - pool->totall_ipv6 = pool->totalh_ipv6 = pool->total_ipv4 = 0; + ns->ip4_root = ns->ip6_root = NULL; + ns->ip4_pools = ns->ip6_pools = NULL; + ns->totall_ipv6 = ns->totalh_ipv6 = ns->total_ipv4 = 0; } -void ipp_free(struct ip_pool *pool) +void ipp_free(struct ipns *ns) { struct radix_pool *next; - radix_free_nodes(pool->ip4_root); - radix_free_nodes(pool->ip6_root); + radix_free_nodes(ns->ip4_root); + radix_free_nodes(ns->ip6_root); - for (struct radix_pool *cur = pool->ip4_pool; cur; cur = next) { + for (struct radix_pool *cur = ns->ip4_pools; cur; cur = next) { next = cur->next; free(cur); } - for (struct radix_pool *cur = pool->ip6_pool; cur; cur = next) { + for (struct radix_pool *cur = ns->ip6_pools; cur; cur = next) { next = cur->next; free(cur); } } -int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr) +int ipp_add_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr) { - int ret = insert_v4(&pool->ip4_root, ip, cidr); + int ret = insert_v4(&ns->ip4_root, ip, cidr); if (!ret) - --pool->total_ipv4; + --ns->total_ipv4; return ret; } -int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr) +int ipp_add_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr) { - int ret = insert_v6(&pool->ip6_root, ip, cidr); + int ret = insert_v6(&ns->ip6_root, ip, cidr); if (!ret) { - if (pool->totall_ipv6 == 0) - --pool->totalh_ipv6; + if (ns->totall_ipv6 == 0) + --ns->totalh_ipv6; - --pool->totall_ipv6; + --ns->totall_ipv6; } return ret; } -int ipp_del_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr) +int ipp_del_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr) { uint8_t key[4] __aligned(__alignof(uint32_t)); int ret; swap_endian(key, (const uint8_t *)ip, 32); - ret = remove_node(pool->ip4_root, key, cidr); + ret = remove_node(ns->ip4_root, key, cidr); if (!ret) - ++pool->total_ipv4; + ++ns->total_ipv4; return ret; } -int ipp_del_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr) +int ipp_del_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr) { uint8_t key[16] __aligned(__alignof(uint64_t)); int ret; swap_endian(key, (const uint8_t *)ip, 128); - ret = remove_node(pool->ip6_root, key, cidr); + ret = remove_node(ns->ip6_root, key, cidr); if (!ret) { - ++pool->totall_ipv6; - if (pool->totall_ipv6 == 0) - ++pool->totalh_ipv6; + ++ns->totall_ipv6; + if (ns->totall_ipv6 == 0) + ++ns->totalh_ipv6; } return ret; } -int ipp_addpool_v4(struct ip_pool *ipp, const struct in_addr *ip, uint8_t cidr) +int ipp_addpool_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr) { uint8_t key[4] __aligned(__alignof(uint32_t)); @@ -600,10 +600,10 @@ int ipp_addpool_v4(struct ip_pool *ipp, const struct in_addr *ip, uint8_t cidr) return -1; swap_endian(key, (const uint8_t *)ip, 32); - return ipp_addpool(ipp, &ipp->ip4_pool, &ipp->ip4_root, 32, key, cidr); + return ipp_addpool(ns, &ns->ip4_pools, &ns->ip4_root, 32, key, cidr); } -int ipp_addpool_v6(struct ip_pool *ipp, const struct in6_addr *ip, uint8_t cidr) +int ipp_addpool_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr) { uint8_t key[16] __aligned(__alignof(uint64_t)); @@ -611,26 +611,26 @@ int ipp_addpool_v6(struct ip_pool *ipp, const struct in6_addr *ip, uint8_t cidr) return -1; swap_endian(key, (const uint8_t *)ip, 128); - return ipp_addpool(ipp, &ipp->ip6_pool, &ipp->ip6_root, 128, key, cidr); + return ipp_addpool(ns, &ns->ip6_pools, &ns->ip6_root, 128, key, cidr); } /* TODO: implement */ -int ipp_removepool_v4(struct ip_pool *pool, const struct in_addr *ip) +int ipp_removepool_v4(struct ipns *ns, const struct in_addr *ip) { return 0; } /* TODO: implement */ -int ipp_removepool_v6(struct ip_pool *pool, const struct in6_addr *ip) +int ipp_removepool_v6(struct ipns *ns, const struct in6_addr *ip) { return 0; } -void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index) +void ipp_addnth_v4(struct ipns *ns, struct in_addr *dest, uint32_t index) { - struct radix_pool *current = pool->ip4_pool; + struct radix_pool *current = ns->ip4_pools; - for (current = pool->ip4_pool; current; current = current->next) { + for (current = ns->ip4_pools; current; current = current->next) { if (current->shadowed) continue; @@ -643,13 +643,13 @@ void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index) BUG_ON(!current); add_nth(current->node, 32, index, (uint8_t *)&dest->s_addr); - --pool->total_ipv4; + --ns->total_ipv4; } -void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, - uint32_t index_low, uint64_t index_high) +void ipp_addnth_v6(struct ipns *ns, struct in6_addr *dest, uint32_t index_low, + uint64_t index_high) { - struct radix_pool *current = pool->ip6_pool; + struct radix_pool *current = ns->ip6_pools; uint64_t tmp; while (current) { @@ -676,8 +676,8 @@ void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, BUG_ON(!current || index_high); add_nth(current->node, 128, index_low, (uint8_t *)&dest->s6_addr); - if (pool->totall_ipv6 == 0) - --pool->totalh_ipv6; + if (ns->totall_ipv6 == 0) + --ns->totalh_ipv6; - --pool->totall_ipv6; + --ns->totall_ipv6; } diff --git a/radix-trie.h b/radix-trie.h index a72ee50..d3599dc 100644 --- a/radix-trie.h +++ b/radix-trie.h @@ -10,37 +10,38 @@ #include <stdbool.h> #include <stdint.h> -struct ip_pool { - uint64_t totall_ipv6; - uint32_t totalh_ipv6, total_ipv4; +struct ipns { + /* Total amount of available addresses over all pools */ + uint64_t totall_ipv6, total_ipv4; + uint32_t totalh_ipv6; + struct radix_node *ip4_root, *ip6_root; - struct radix_pool *ip4_pool, *ip6_pool; + struct radix_pool *ip4_pools, *ip6_pools; }; -void ipp_init(struct ip_pool *pool); -void ipp_free(struct ip_pool *pool); +void ipp_init(struct ipns *ns); +void ipp_free(struct ipns *ns); -int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr); -int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr); +int ipp_add_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr); +int ipp_add_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr); -int ipp_del_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr); -int ipp_del_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr); +int ipp_del_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr); +int ipp_del_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr); -void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index); -void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, - uint32_t index_low, uint64_t index_high); +void ipp_addnth_v4(struct ipns *ns, struct in_addr *dest, uint32_t index); +void ipp_addnth_v6(struct ipns *ns, struct in6_addr *dest, uint32_t index_low, + uint64_t index_high); -int ipp_addpool_v4(struct ip_pool *ipp, const struct in_addr *ip, uint8_t cidr); -int ipp_addpool_v6(struct ip_pool *ipp, const struct in6_addr *ip, - uint8_t cidr); +int ipp_addpool_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr); +int ipp_addpool_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr); -int ipp_removepool_v4(struct ip_pool *pool, const struct in_addr *ip); -int ipp_removepool_v6(struct ip_pool *pool, const struct in6_addr *ip); +int ipp_removepool_v4(struct ipns *ns, const struct in_addr *ip); +int ipp_removepool_v6(struct ipns *ns, const struct in6_addr *ip); #ifdef DEBUG void node_to_str(struct radix_node *node, char *buf, uint8_t bits); -void debug_print_trie_v4(struct ip_pool *pool); -void debug_print_trie_v6(struct ip_pool *pool); +void debug_print_trie_v4(struct ipns *ns); +void debug_print_trie_v6(struct ipns *ns); #endif #endif @@ -68,24 +68,27 @@ get_random_bytes(uint8_t *out, size_t len) return i == len; } -uint64_t random_bounded(uint64_t bound) +uint64_t random_u64() { uint64_t ret; + if (!get_random_bytes((uint8_t *)&ret, sizeof(ret))) + fatal("get_random_bytes()"); - if (bound == 0) - return 0; + return ret; +} - if (bound == 1) { - if (!get_random_bytes((uint8_t *)&ret, sizeof(ret))) - fatal("get_random_bytes()"); - return (ret > 0x7FFFFFFFFFFFFFFF) ? 1 : 0; - } +/* Returns a random number [0, bound) (exclusive) */ +uint64_t random_bounded(uint64_t bound) +{ + uint64_t ret, max_mod_bound; + + if (bound < 2) + return 0; - const uint64_t max_mod_bound = (1 + ~bound) % bound; + max_mod_bound = (1 + ~bound) % bound; do { - if (!get_random_bytes((uint8_t *)&ret, sizeof(ret))) - fatal("get_random_bytes()"); + ret = random_u64(); } while (ret < max_mod_bound); return ret % bound; @@ -8,6 +8,7 @@ #include <stdint.h> +uint64_t random_u64(); uint64_t random_bounded(uint64_t bound); #endif diff --git a/wg-dynamic-client.c b/wg-dynamic-client.c index f3e3274..2edb413 100644 --- a/wg-dynamic-client.c +++ b/wg-dynamic-client.c @@ -2,446 +2,288 @@ * * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. */ +#define _POSIX_C_SOURCE 200112L -#include <fcntl.h> -#include <poll.h> +#include <arpa/inet.h> #include <signal.h> -#include <stdio.h> -#include <string.h> -#include <stdlib.h> +#include <stdbool.h> +#include <stdint.h> +#include <sys/socket.h> +#include <sys/time.h> +#include <sys/types.h> +#include <time.h> #include <unistd.h> -#include <arpa/inet.h> -#include <libmnl/libmnl.h> -#include <linux/rtnetlink.h> - #include "common.h" #include "dbg.h" +#include "ipm.h" #include "netlink.h" -struct wg_dynamic_lease { - struct wg_combined_ip ip4; - struct wg_combined_ip ip6; - uint32_t start; - uint32_t leasetime; - struct wg_dynamic_lease *next; -}; - static const char *progname; static const char *wg_interface; +static struct in6_addr well_known; +static struct in6_addr lladdr; + +static struct in_addr ipv4; +static struct in6_addr ipv6; +static bool ipv4_assigned = false, ipv6_assigned = false; + static wg_device *device = NULL; -static int our_fd = -1; -static struct in6_addr our_lladdr = { 0 }; -static struct wg_combined_ip our_gaddr4 = { 0 }; -static struct wg_combined_ip our_gaddr6 = { 0 }; -static struct wg_dynamic_lease our_lease = { 0 }; - -struct mnl_cb_data { - uint32_t ifindex; - struct in6_addr *lladdr; - struct wg_combined_ip *gaddr4; - struct wg_combined_ip *gaddr6; -}; +static int sockfd = -1; + +static volatile sig_atomic_t should_exit = 0; static void usage() { die("usage: %s <wg-interface>\n", progname); } -int data_cb(const struct nlmsghdr *nlh, void *data) -{ - struct nlattr *tb[IFA_MAX + 1] = {}; - struct ifaddrmsg *ifa = mnl_nlmsg_get_payload(nlh); - struct mnl_cb_data *cb_data = (struct mnl_cb_data *)data; - unsigned char *addr; - - if (ifa->ifa_index != cb_data->ifindex) - return MNL_CB_OK; - - mnl_attr_parse(nlh, sizeof(*ifa), data_attr_cb, tb); - - if (!tb[IFA_ADDRESS]) - return MNL_CB_OK; - - addr = mnl_attr_get_payload(tb[IFA_ADDRESS]); - char out[INET6_ADDRSTRLEN]; - inet_ntop(ifa->ifa_family, addr, out, sizeof(out)); - debug("index=%d, family=%d, addr=%s\n", ifa->ifa_index, ifa->ifa_family, - out); - - if (ifa->ifa_scope == RT_SCOPE_LINK) { - if (ifa->ifa_prefixlen != 128) - return MNL_CB_OK; - memcpy(cb_data->lladdr, addr, 16); - } else if (ifa->ifa_scope == RT_SCOPE_UNIVERSE) { - switch (ifa->ifa_family) { - case AF_INET: - cb_data->gaddr4->family = ifa->ifa_family; - memcpy(&cb_data->gaddr4->ip4, addr, 4); - cb_data->gaddr4->cidr = ifa->ifa_prefixlen; - break; - case AF_INET6: - cb_data->gaddr6->family = ifa->ifa_family; - memcpy(&cb_data->gaddr6->ip6, addr, 16); - cb_data->gaddr6->cidr = ifa->ifa_prefixlen; - break; - default: - die("Unknown address family: %u\n", ifa->ifa_family); - } - } - - return MNL_CB_OK; -} - -static void iface_update(uint16_t cmd, uint16_t flags, uint32_t ifindex, - const struct wg_combined_ip *addr) +/* NOTE: do NOT call exit() in here */ +static void cleanup() { - struct mnl_socket *nl; - char buf[MNL_SOCKET_BUFFER_SIZE]; - struct nlmsghdr *nlh; - unsigned int seq, portid; - struct ifaddrmsg *ifaddr; /* linux/if_addr.h */ - int ret; - - nl = mnl_socket_open(NETLINK_ROUTE); - if (nl == NULL) - fatal("mnl_socket_open"); - - if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) - fatal("mnl_socket_bind"); - - portid = mnl_socket_get_portid(nl); - seq = time(NULL); - nlh = mnl_nlmsg_put_header(buf); - nlh->nlmsg_seq = seq; - nlh->nlmsg_type = cmd; - nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | flags; - ifaddr = mnl_nlmsg_put_extra_header(nlh, sizeof(struct ifaddrmsg)); - ifaddr->ifa_family = addr->family; - ifaddr->ifa_prefixlen = addr->cidr; - ifaddr->ifa_scope = RT_SCOPE_UNIVERSE; /* linux/rtnetlink.h */ - ifaddr->ifa_index = ifindex; - mnl_attr_put(nlh, IFA_LOCAL, addr->family == AF_INET ? 4 : 16, &addr); - - if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) - fatal("mnl_socket_sendto"); - - do { - ret = mnl_socket_recvfrom(nl, buf, sizeof(buf)); - if (ret <= MNL_CB_STOP) - break; - ret = mnl_cb_run(buf, ret, seq, portid, NULL, NULL); - } while (ret > 0); - - if (ret == -1) - fatal("mnl_cb_run/mnl_socket_recvfrom"); + if (ipv4_assigned && ipm_deladdr_v4(device->ifindex, &ipv4)) + debug("Failed to cleanup ipv4 address"); + if (ipv6_assigned && ipm_deladdr_v6(device->ifindex, &ipv6)) + debug("Failed to cleanup ipv6 address"); - mnl_socket_close(nl); -} + if (sockfd >= 0) + close(sockfd); -static void iface_remove_addr(uint32_t ifindex, - const struct wg_combined_ip *addr) -{ - char ipstr[INET6_ADDRSTRLEN]; - debug("removing %s/%u from interface %u\n", - inet_ntop(addr->family, &addr, ipstr, sizeof ipstr), addr->cidr, - ifindex); - iface_update(RTM_DELADDR, 0, ifindex, addr); + ipm_free(); + wg_free_device(device); } -static void iface_add_addr(uint32_t ifindex, const struct wg_combined_ip *addr) +static void handler(int signum) { - char ipstr[INET6_ADDRSTRLEN]; - debug("adding %s/%u to interface %u\n", - inet_ntop(addr->family, &addr, ipstr, sizeof ipstr), addr->cidr, - ifindex); - iface_update(RTM_NEWADDR, NLM_F_REPLACE | NLM_F_CREATE, ifindex, addr); + UNUSED(signum); + should_exit = 1; } -static bool get_and_validate_local_addrs(uint32_t ifindex, - struct in6_addr *lladdr, - struct wg_combined_ip *gaddr4, - struct wg_combined_ip *gaddr6) +static void check_signal() { - struct mnl_cb_data cb_data = { - .ifindex = ifindex, - .lladdr = lladdr, - .gaddr4 = gaddr4, - .gaddr6 = gaddr6, - }; - - iface_get_all_addrs(AF_INET, data_cb, &cb_data); - iface_get_all_addrs(AF_INET6, data_cb, &cb_data); - - return !IN6_IS_ADDR_UNSPECIFIED(cb_data.lladdr); + if (should_exit) + exit(EXIT_FAILURE); } -static int try_connect(int *fd) +static int request_ip(struct wg_dynamic_request_ip *rip) { - struct timeval tval = { .tv_sec = 1, .tv_usec = 0 }; - struct sockaddr_in6 our_addr = { + unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE]; + size_t msglen, remaining = 0, off = 0; + struct sockaddr_in6 dstaddr = { .sin6_family = AF_INET6, - .sin6_addr = our_lladdr, + .sin6_addr = well_known, .sin6_port = htons(WG_DYNAMIC_PORT), .sin6_scope_id = device->ifindex, }; - struct sockaddr_in6 their_addr = { + struct sockaddr_in6 srcaddr = { .sin6_family = AF_INET6, + .sin6_addr = lladdr, .sin6_port = htons(WG_DYNAMIC_PORT), .sin6_scope_id = device->ifindex, }; + struct wg_dynamic_request req = { + .cmd = WGKEY_REQUEST_IP, + .version = 1, + .result = rip, + }; + struct timeval timeout = { .tv_sec = 30 }; + ssize_t ret; + int val = 1; - *fd = socket(AF_INET6, SOCK_STREAM, 0); - if (*fd < 0) + sockfd = socket(AF_INET6, SOCK_STREAM, 0); + if (sockfd < 0) fatal("Creating a socket failed"); - if (setsockopt(*fd, SOL_SOCKET, SO_RCVTIMEO, &tval, sizeof tval) == -1) - fatal("Setting socket option failed"); + if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val)) + fatal("setsockopt(SO_REUSEADDR)"); + + if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, + sizeof timeout)) + fatal("setsockopt(SO_RCVTIMEO)"); + + if (setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, + sizeof timeout)) + fatal("setsockopt(SO_SNDTIMEO)"); - if (bind(*fd, (struct sockaddr *)&our_addr, sizeof(our_addr))) + if (bind(sockfd, (struct sockaddr *)&srcaddr, sizeof(srcaddr))) fatal("Binding socket failed"); - if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &their_addr.sin6_addr) != 1) - fatal("inet_pton()"); + if (connect(sockfd, (struct sockaddr *)&dstaddr, sizeof(dstaddr))) + fatal("connect()"); + + if (ipv4_assigned) { + memcpy(&rip->ipv4, &ipv4, sizeof rip->ipv4); + rip->has_ipv4 = true; + } + if (ipv6_assigned) { + memcpy(&rip->ipv6, &ipv6, sizeof rip->ipv6); + rip->has_ipv6 = true; + } + + msglen = serialize_request_ip(true, (char *)buf, RECV_BUFSIZE, rip); + do { + ssize_t written = write(sockfd, buf + off, msglen - off); + if (written == -1) { + if (errno == EINTR) { + check_signal(); + continue; + } + + fatal("write()"); + } - if (connect(*fd, (struct sockaddr *)&their_addr, - sizeof(struct sockaddr_in6))) { - char out[INET6_ADDRSTRLEN]; + off += written; + } while (off < msglen); - if (!inet_ntop(their_addr.sin6_family, &their_addr.sin6_addr, - out, sizeof out)) - fatal("inet_ntop()"); - debug("Connecting to [%s]:%u failed: %s\n", out, - ntohs(their_addr.sin6_port), strerror(errno)); + memset(rip, 0, sizeof *rip); + + while ((ret = handle_request(sockfd, &req, buf, &remaining)) <= 0) { + if (ret == 0) { + check_signal(); + continue; + } + + if (close(sockfd)) + debug("Failed to close socket: %s\n", strerror(errno)); - if (close(*fd)) - debug("Closing socket failed: %s\n", strerror(errno)); - *fd = -1; return -1; } - return 0; -} + if (remaining > 0) + log_err("Warning: discarding %zu extra bytes sent by the server\n", + remaining); -static void request_ip(int fd, const struct wg_dynamic_lease *lease) -{ - unsigned char buf[MAX_RESPONSE_SIZE + 1]; - char addrstr[INET6_ADDRSTRLEN]; - size_t msglen; - - msglen = 0; - msglen += print_to_buf((char *)buf, sizeof buf, msglen, "%s=%d\n", - WG_DYNAMIC_KEY[WGKEY_REQUEST_IP], 1); - - if (lease && lease->ip4.ip4.s_addr) { - if (!inet_ntop(AF_INET, &lease->ip4.ip4, addrstr, - sizeof addrstr)) - fatal("inet_ntop()"); - msglen += print_to_buf((char *)buf, sizeof buf, msglen, - "ipv4=%s/32\n", addrstr); + if (rip->wg_errno) + return -1; + + if (!ipv4_assigned || memcmp(&ipv4, &rip->ipv4, sizeof ipv4)) { + if (ipv4_assigned && ipm_deladdr_v4(device->ifindex, &ipv4)) + fatal("ipm_deladdr_v4()"); + + memcpy(&ipv4, &rip->ipv4, sizeof ipv4); + if (ipm_newaddr_v4(device->ifindex, &ipv4)) + fatal("ipm_newaddr_v4()"); + ipv4_assigned = true; } - if (lease && !IN6_IS_ADDR_UNSPECIFIED(&lease->ip6.ip6)) { - if (!inet_ntop(AF_INET6, &lease->ip6.ip6, addrstr, - sizeof addrstr)) - fatal("inet_ntop()"); - msglen += print_to_buf((char *)buf, sizeof buf, msglen, - "ipv6=%s/128\n", addrstr); + + if (!ipv6_assigned || memcmp(&ipv6, &rip->ipv6, sizeof ipv6)) { + if (ipv6_assigned && ipm_deladdr_v6(device->ifindex, &ipv6)) + fatal("ipm_deladdr_v6()"); + + memcpy(&ipv6, &rip->ipv6, sizeof ipv6); + if (ipm_newaddr_v6(device->ifindex, &ipv6)) + fatal("ipm_newaddr_v6()"); + ipv6_assigned = true; } - /* nmsglen += print_to_buf((char *)buf, sizeof buf, msglen, - "leasetime=%u\n", fixme); */ - msglen += print_to_buf((char *)buf, sizeof buf, msglen, "\n"); + if (close(sockfd)) + debug("Failed to close socket: %s\n", strerror(errno)); - send_message(fd, buf, &msglen); + return 0; } -static uint32_t time_until_refresh(uint32_t now, struct wg_dynamic_lease *lease) +static void setup() { - uint32_t refresh_at; + struct sigaction sa = { .sa_handler = handler, .sa_flags = 0 }; + struct wg_combined_ip ip; + int ret; - if (lease->leasetime == 0) - return 0; - refresh_at = lease->start + (lease->leasetime * 8) / 10; + if (atexit(cleanup)) + die("Failed to set exit function\n"); - if (refresh_at < now) - return 0; - return refresh_at - now; -} + sigemptyset(&sa.sa_mask); + if (sigaction(SIGINT, &sa, NULL) == -1) + fatal("sigaction()"); -static int handle_received_lease(const struct wg_dynamic_request *req) -{ - uint32_t ret; - struct wg_dynamic_attr *attr; - struct wg_dynamic_lease *lease = &our_lease; - uint32_t now = current_time(); - uint32_t lease_start = 0; - uint32_t curleasetime = lease->start + lease->leasetime; - - attr = req->first; - while (attr) { - switch (attr->key) { - case WGKEY_IPV4: - memcpy(&lease->ip4, attr->value, - sizeof(struct wg_combined_ip)); - break; - case WGKEY_IPV6: - memcpy(&lease->ip6, attr->value, - sizeof(struct wg_combined_ip)); - break; - case WGKEY_LEASESTART: - memcpy(&lease_start, attr->value, sizeof(uint32_t)); - break; - case WGKEY_LEASETIME: - memcpy(&lease->leasetime, attr->value, - sizeof(uint32_t)); - break; - case WGKEY_ERRNO: - memcpy(&ret, attr->value, sizeof(uint32_t)); - if (ret) { - debug("Request IP failed with %ud from server\n", - ret); - return -ret; - } - break; - case WGKEY_ERRMSG: - /* TODO: do something with the error message */ - break; - default: - debug("Ignoring invalid attribute for request_ip: %d\n", - attr->key); - } - attr = attr->next; - } - - if (lease->leasetime == 0 || (lease->ip4.ip4.s_addr == 0 && - IN6_IS_ADDR_UNSPECIFIED(&lease->ip6.ip6))) - return -EINVAL; + if (wg_get_device(&device, wg_interface)) + fatal("Unable to access interface %s", wg_interface); - if (abs(now - lease_start) < 15) - lease->start = lease_start; - else - lease->start = now; + if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known) != 1) + fatal("inet_pton()"); - debug("Replacing lease %u -> %u\n", curleasetime, - lease->start + lease->leasetime); + ipm_init(); - return 0; -} + ret = ipm_getlladdr(device->ifindex, &ip); + if (ret == -1) + fatal("ipm_getlladdr()"); -static void cleanup() -{ - wg_free_device(device); - if (our_fd != -1 && close(our_fd)) - debug("Failed to close fd %d\n", our_fd); -} + if (ret == -2 || ip.family != AF_INET6) + die("%s needs to be assigned an IPv6 link local address\n", + wg_interface); -static bool handle_error(int fd, int ret) -{ - UNUSED(fd); - UNUSED(ret); + if (ret == -3) + die("Interface must not have multiple link-local addresses assigned\n"); - debug("Unable to parse response: %s\n", strerror(ret)); + if (ip.cidr != 128) + die("Link-local address must have a CIDR of 128\n"); - return true; + memcpy(&lladdr, &ip, 16); } -static void maybe_update_iface() +static void xnanosleep(time_t duration) { - if (memcmp(&our_gaddr4, &our_lease.ip4, sizeof our_gaddr4) || - our_gaddr4.cidr != our_lease.ip4.cidr) { - if (our_gaddr4.ip4.s_addr) - iface_remove_addr(device->ifindex, &our_gaddr4); - iface_add_addr(device->ifindex, &our_lease.ip4); - memcpy(&our_gaddr4, &our_lease.ip4, sizeof our_gaddr4); - } - if (memcmp(&our_gaddr6, &our_lease.ip6, sizeof our_gaddr6) || - our_gaddr6.cidr != our_lease.ip6.cidr) { - if (!IN6_IS_ADDR_UNSPECIFIED(&our_gaddr6.ip6)) - iface_remove_addr(device->ifindex, &our_gaddr6); - iface_add_addr(device->ifindex, &our_lease.ip6); - memcpy(&our_gaddr6, &our_lease.ip6, sizeof our_gaddr6); + struct timespec rem, timeout = { .tv_sec = duration }; + int ret; + + while ((ret = clock_nanosleep(CLOCK_BOOTTIME, 0, &timeout, &rem))) { + if (ret == EINTR) { + check_signal(); + memcpy(&timeout, &rem, sizeof timeout); + continue; + } + + die("clock_nanosleep(): %s\n", strerror(ret)); } } -static bool handle_response(int fd, struct wg_dynamic_request *req) +static void loop() { - UNUSED(fd); - -#if 0 - printf("Recieved response of type %s.\n", WG_DYNAMIC_KEY[req->cmd]); - struct wg_dynamic_attr *cur = req->first; - while (cur) { - printf(" with attr %s.\n", WG_DYNAMIC_KEY[cur->key]); - cur = cur->next; + struct wg_dynamic_request_ip rip = { 0 }; + struct timespec tsend, trecv; + time_t expires, timeout; + + if (clock_gettime(CLOCK_REALTIME, &tsend)) + fatal("clock_gettime(CLOCK_REALTIME)"); + + if (request_ip(&rip)) { + /* TODO: implement some sort of exponential backoff */ + debug("Server communication error, trying again in 30s\n"); + xnanosleep(30); + return; } -#endif - - switch (req->cmd) { - case WGKEY_REQUEST_IP: - if (handle_received_lease(req) == 0) - maybe_update_iface(); - break; - default: - debug("Unknown command: %d\n", req->cmd); - return true; + + if (clock_gettime(CLOCK_REALTIME, &trecv)) + fatal("clock_gettime(CLOCK_REALTIME)"); + + if (tsend.tv_sec < rip.start + 5 || rip.start > trecv.tv_sec + 5) + expires = tsend.tv_sec + rip.leasetime; + else + expires = MIN(rip.leasetime, trecv.tv_sec) + rip.leasetime; + + if (expires <= trecv.tv_sec) { + log_err("Warning: lease we tried to aquire already expired\n"); + return; } - return true; + /* TODO: implement random jitter */ + timeout = (expires - trecv.tv_sec); + timeout -= MIN(30, timeout * 0.5); + + debug("Sleeping for %zus\n", timeout); + xnanosleep(timeout); } -int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) +int main(int argc, char *argv[]) { - int *fd = &our_fd; - struct wg_dynamic_request req = { 0 }; - progname = argv[0]; if (argc != 2) usage(); wg_interface = argv[1]; + setup(); - if (wg_get_device(&device, wg_interface)) - fatal("Unable to access interface %s", wg_interface); - - if (atexit(cleanup)) - die("Failed to set exit function\n"); - - if (!get_and_validate_local_addrs(device->ifindex, &our_lladdr, - &our_gaddr4, &our_gaddr6)) - die("%s needs to have an IPv6 link local address with prefixlen 128 assigned\n", - wg_interface); - // TODO: verify that we have a peer with an allowed-ips including fe80::/128 - - char lladr_str[INET6_ADDRSTRLEN]; - debug("%s: %s\n", wg_interface, - inet_ntop(AF_INET6, &our_lladdr, lladr_str, sizeof lladr_str)); - - /* If we have an address configured, let's assume it's from a - * lease in order to get renewal done. */ - if (our_gaddr4.ip4.s_addr || - !IN6_IS_ADDR_UNSPECIFIED(&our_gaddr6.ip6)) { - our_lease.start = current_time(); - our_lease.leasetime = 15; - memcpy(&our_lease.ip4, &our_gaddr4, - sizeof(struct wg_combined_ip)); - memcpy(&our_lease.ip6, &our_gaddr6, - sizeof(struct wg_combined_ip)); - } - - while (1) { - sleep(time_until_refresh(current_time(), &our_lease)); - - if (*fd == -1 && try_connect(fd)) { - sleep(1); - continue; - } - - request_ip(*fd, &our_lease); - - while (!handle_request(&req, handle_response, handle_error)) - ; - close_connection(&req); - } + while (1) + loop(); return 0; } diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c index 27f4054..feb9656 100644 --- a/wg-dynamic-server.c +++ b/wg-dynamic-server.c @@ -23,6 +23,7 @@ #include "common.h" #include "dbg.h" +#include "ipm.h" #include "khash.h" #include "lease.h" #include "netlink.h" @@ -32,7 +33,6 @@ static const char *wg_interface = NULL; static struct in6_addr well_known; static wg_device *device = NULL; -static struct wg_dynamic_request requests[MAX_CONNECTIONS] = { 0 }; static uint32_t leasetime = WG_DYNAMIC_DEFAULT_LEASETIME; static int sockfd = -1; @@ -42,60 +42,22 @@ static struct mnl_socket *nlsock = NULL; KHASH_MAP_INIT_INT64(allowedht, wg_key *) khash_t(allowedht) * allowedips_ht; -struct mnl_cb_data { - uint32_t ifindex; - bool valid_ip_found; +struct wg_dynamic_connection { + struct wg_dynamic_request req; + int fd; + wg_key pubkey; + struct in6_addr lladdr; + unsigned char *outbuf; + size_t buflen; }; +static struct wg_dynamic_connection connections[MAX_CONNECTIONS] = { 0 }; + static void usage() { die("usage: %s [--leasetime <leasetime>] <wg-interface>\n", progname); } -static int data_cb(const struct nlmsghdr *nlh, void *data) -{ - struct nlattr *tb[IFA_MAX + 1] = {}; - struct ifaddrmsg *ifa = mnl_nlmsg_get_payload(nlh); - struct mnl_cb_data *cb_data = (struct mnl_cb_data *)data; - unsigned char *addr; - - if (ifa->ifa_index != cb_data->ifindex) - return MNL_CB_OK; - - if (ifa->ifa_scope != RT_SCOPE_LINK) - return MNL_CB_OK; - - mnl_attr_parse(nlh, sizeof(*ifa), data_attr_cb, tb); - - if (!tb[IFA_ADDRESS]) - return MNL_CB_OK; - - addr = mnl_attr_get_payload(tb[IFA_ADDRESS]); - char out[INET6_ADDRSTRLEN]; - inet_ntop(ifa->ifa_family, addr, out, sizeof(out)); - debug("index=%d, family=%d, addr=%s\n", ifa->ifa_index, ifa->ifa_family, - out); - - if (ifa->ifa_prefixlen != 64 || memcmp(addr, well_known.s6_addr, 16)) - return MNL_CB_OK; - - cb_data->valid_ip_found = true; - - return MNL_CB_OK; -} - -static bool validate_link_local_ip(uint32_t ifindex) -{ - struct mnl_cb_data cb_data = { - .ifindex = ifindex, - .valid_ip_found = false, - }; - - iface_get_all_addrs(AF_INET6, data_cb, &cb_data); - - return cb_data.valid_ip_found; -} - static bool valid_peer_found(wg_device *device) { wg_peer *peer; @@ -167,7 +129,7 @@ static wg_key *addr_to_pubkey(struct sockaddr_storage *addr) return NULL; } -static int accept_connection(int sockfd, wg_key *dest_pubkey, +static int accept_connection(wg_key *dest_pubkey, struct in6_addr *dest_lladdr) { int fd; @@ -227,101 +189,113 @@ static int accept_connection(int sockfd, wg_key *dest_pubkey, return fd; } -static bool send_error(struct wg_dynamic_request *req, int error) +static bool send_message(struct wg_dynamic_connection *con, + const unsigned char *buf, size_t len) { - char buf[MAX_RESPONSE_SIZE]; - size_t msglen = 0; + size_t offset = 0; - print_to_buf(buf, sizeof buf, &msglen, "errno=%d\nerrmsg=%s\n\n", error, - WG_DYNAMIC_ERR[error]); - - return send_message(req, buf, msglen); -} + while (1) { + ssize_t written = write(con->fd, buf + offset, len - offset); + if (written < 0) { + if (errno == EWOULDBLOCK || errno == EAGAIN) + break; -static size_t serialize_lease(char *buf, size_t len, - const struct wg_dynamic_lease *lease) -{ - char addrbuf[INET6_ADDRSTRLEN]; - size_t off = 0; + if (errno == EINTR) + continue; - if (lease->ipv4.s_addr) { - if (!inet_ntop(AF_INET, &lease->ipv4, addrbuf, sizeof addrbuf)) - fatal("inet_ntop()"); + debug("Writing to socket %d failed: %s\n", con->fd, + strerror(errno)); + return false; + } - print_to_buf(buf, len, &off, "ipv4=%s/%d\n", addrbuf, 32); + offset += written; + if (offset == len) + return true; } - if (!IN6_IS_ADDR_UNSPECIFIED(&lease->ipv6)) { - if (!inet_ntop(AF_INET6, &lease->ipv6, addrbuf, sizeof addrbuf)) - fatal("inet_ntop()"); - - print_to_buf(buf, len, &off, "ipv6=%s/%d\n", addrbuf, 128); + debug("Socket %d blocking on write with %lu bytes left, postponing\n", + con->fd, len - offset); + + if (!con->outbuf) { + con->buflen = len - offset; + con->outbuf = malloc(con->buflen); + if (!con->outbuf) + fatal("malloc()"); + memcpy(con->outbuf, buf + offset, con->buflen); + } else { + con->buflen = len - offset; + memmove(con->outbuf, buf + offset, con->buflen); } - print_to_buf(buf, len, &off, "leasestart=%u\nleasetime=%u\nerrno=0\n\n", - lease->start_real, lease->leasetime); - - return off; + return true; } -static int response_request_ip(struct wg_dynamic_attr *cur, wg_key pubkey, - const struct in6_addr *lladdr, - struct wg_dynamic_lease **lease) +void close_connection(struct wg_dynamic_connection *con) { - struct in_addr *ipv4 = NULL; - struct in6_addr *ipv6 = NULL; - - while (cur) { - switch (cur->key) { - case WGKEY_IPV4: - ipv4 = &((struct wg_combined_ip *)cur->value)->ip4; - break; - case WGKEY_IPV6: - ipv6 = &((struct wg_combined_ip *)cur->value)->ip6; - break; - case WGKEY_LEASETIME: - leasetime = *(uint32_t *)cur->value; - break; - default: - debug("Ignoring invalid attribute for request_ip: %d\n", - cur->key); - } - cur = cur->next; - } + free_wg_dynamic_request(&con->req); - *lease = set_lease(wg_interface, pubkey, leasetime, lladdr, ipv4, ipv6); - if (!*lease) - return E_IP_UNAVAIL; + if (close(con->fd)) + debug("Failed to close socket\n"); - return E_NO_ERROR; + con->fd = -1; + memset(con->pubkey, 0, sizeof con->pubkey); + free(con->outbuf); + con->outbuf = NULL; + con->buflen = 0; } -static bool send_response(struct wg_dynamic_request *req) +static bool send_response(struct wg_dynamic_connection *con) { char buf[MAX_RESPONSE_SIZE]; - struct wg_dynamic_attr *cur = req->first; - struct wg_dynamic_lease *lease; size_t msglen; - int ret; - switch (req->cmd) { - case WGKEY_REQUEST_IP: - ret = response_request_ip(cur, req->pubkey, &req->lladdr, - &lease); - if (ret) - break; + switch (con->req.cmd) { + case WGKEY_REQUEST_IP:; + struct wg_dynamic_request_ip *rip = con->req.result; + struct in6_addr *lladdr = &con->lladdr; + struct in_addr *ip4 = rip->has_ipv4 ? &rip->ipv4 : NULL; + struct in6_addr *ip6 = rip->has_ipv6 ? &rip->ipv6 : NULL; + struct wg_dynamic_lease *lease; + struct wg_dynamic_request_ip ans = { 0 }; + + lease = set_lease(wg_interface, con->pubkey, leasetime, lladdr, ip4, ip6); + if (lease) { + memcpy(&ans.ipv4, &lease->ipv4, sizeof ans.ipv4); + memcpy(&ans.ipv6, &lease->ipv6, sizeof ans.ipv6); + ans.has_ipv4 = ans.has_ipv6 = true; + ans.start = lease->start_real; + ans.leasetime = lease->leasetime; + } else { + ans.wg_errno = E_IP_UNAVAIL; + } - msglen = serialize_lease(buf, sizeof buf, lease); + msglen = serialize_request_ip(false, buf, sizeof buf, &ans); break; default: - debug("Unknown command: %d\n", req->cmd); + debug("Unknown command: %d\n", con->req.cmd); BUG(); } - if (ret) - return send_error(req, ret); + return send_message(con, (unsigned char *)buf, msglen); +} - return send_message(req, buf, msglen); +static void handle_client(struct wg_dynamic_connection *con) +{ + unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE]; + size_t rem = 0; + ssize_t ret; + + while ((ret = handle_request(con->fd, &con->req, buf, &rem)) > 0) { + if (!send_response(con)) { + close_connection(con); + break; + } + + free_wg_dynamic_request(&con->req); + } + + if (ret < 0) + close_connection(con); } static void setup_sockets() @@ -391,10 +365,10 @@ static void cleanup() close(epollfd); for (int i = 0; i < MAX_CONNECTIONS; ++i) { - if (requests[i].fd < 0) + if (connections[i].fd < 0) continue; - close_connection(&requests[i]); + close_connection(&connections[i]); } } @@ -427,24 +401,39 @@ static void init_leaess_from_peers() static void setup() { + struct wg_combined_ip ip; + int ret; + if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known) != 1) fatal("inet_pton()"); allowedips_ht = kh_init(allowedht); for (int i = 0; i < MAX_CONNECTIONS; ++i) - requests[i].fd = -1; + connections[i].fd = -1; if (atexit(cleanup)) die("Failed to set exit function\n"); rebuild_allowedips_ht(); - if (!validate_link_local_ip(device->ifindex)) - // TODO: assign IP instead? + ipm_init(); + ret = ipm_getlladdr(device->ifindex, &ip); + if (ret == -1) + fatal("ipm_getlladdr()"); + if (ret == -2) + die("Interface must not have multiple link-local addresses assigned\n"); + ipm_free(); + + if (ret == -1 || ip.family != AF_INET6 || + memcmp(&ip.ip6, well_known.s6_addr, 16)) + /* TODO: assign IP instead? */ die("%s needs to have %s assigned\n", wg_interface, WG_DYNAMIC_ADDR); + if (ip.cidr != 64) + die("Link-local address must have a CIDR of 64\n"); + if (!valid_peer_found(device)) die("%s has no peers with link-local allowedips\n", wg_interface); @@ -460,48 +449,46 @@ static int get_avail_request() if (nfds >= MAX_CONNECTIONS) return -1; - if (requests[nfds].fd < 0) + if (connections[nfds].fd < 0) return nfds; } } -static void accept_incoming(int sockfd, int epollfd, - struct wg_dynamic_request *requests) +static void accept_incoming() { int n, fd; struct epoll_event ev; while ((n = get_avail_request()) >= 0) { - fd = accept_connection(sockfd, &requests[n].pubkey, - &requests[n].lladdr); + fd = accept_connection(&connections[n].pubkey, &connections[n].lladdr); if (fd < 0) { if (fd == -ENOENT) { debug("Failed to match IP to pubkey\n"); continue; - } else if (fd != -EAGAIN && fd != -EWOULDBLOCK) { - debug("Failed to accept connection: %s\n", - strerror(-fd)); - continue; + } else if (fd == -EAGAIN || fd == -EWOULDBLOCK) { + return; } - break; + debug("Failed to accept connection: %s\n", + strerror(-fd)); + continue; } ev.events = EPOLLIN | EPOLLET; - ev.data.ptr = &requests[n]; + ev.data.ptr = &connections[n]; if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) fatal("epoll_ctl()"); - requests[n].fd = fd; + connections[n].fd = fd; } } static void handle_event(void *ptr, uint32_t events) { - struct wg_dynamic_request *req; + struct wg_dynamic_connection *con; if (ptr == &sockfd) { - accept_incoming(sockfd, epollfd, requests); + accept_incoming(); return; } @@ -510,15 +497,14 @@ static void handle_event(void *ptr, uint32_t events) return; } - req = (struct wg_dynamic_request *)ptr; + con = (struct wg_dynamic_connection *)ptr; if (events & EPOLLIN) { - if (handle_request(req, send_response, send_error)) - close_connection(req); + handle_client(con); } if (events & EPOLLOUT) { - if (send_message(req, req->buf, req->buflen)) - close_connection(req); + if (!send_message(con, con->outbuf, con->buflen)) + close_connection(con); } } |