aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c373
-rw-r--r--drivers/vhost/vhost.c80
-rw-r--r--drivers/vhost/vhost.h11
3 files changed, 338 insertions, 126 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 686dc670fd29..4e656f89cb22 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -78,6 +78,10 @@ enum {
};
enum {
+ VHOST_NET_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
+};
+
+enum {
VHOST_NET_VQ_RX = 0,
VHOST_NET_VQ_TX = 1,
VHOST_NET_VQ_MAX = 2,
@@ -94,7 +98,7 @@ struct vhost_net_ubuf_ref {
struct vhost_virtqueue *vq;
};
-#define VHOST_RX_BATCH 64
+#define VHOST_NET_BATCH 64
struct vhost_net_buf {
void **queue;
int tail;
@@ -168,7 +172,7 @@ static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
rxq->head = 0;
rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue,
- VHOST_RX_BATCH);
+ VHOST_NET_BATCH);
return rxq->tail;
}
@@ -396,13 +400,10 @@ static inline unsigned long busy_clock(void)
return local_clock() >> 10;
}
-static bool vhost_can_busy_poll(struct vhost_dev *dev,
- unsigned long endtime)
+static bool vhost_can_busy_poll(unsigned long endtime)
{
- return likely(!need_resched()) &&
- likely(!time_after(busy_clock(), endtime)) &&
- likely(!signal_pending(current)) &&
- !vhost_has_work(dev);
+ return likely(!need_resched() && !time_after(busy_clock(), endtime) &&
+ !signal_pending(current));
}
static void vhost_net_disable_vq(struct vhost_net *n,
@@ -431,21 +432,42 @@ static int vhost_net_enable_vq(struct vhost_net *n,
return vhost_poll_start(poll, sock->file);
}
+static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq)
+{
+ struct vhost_virtqueue *vq = &nvq->vq;
+ struct vhost_dev *dev = vq->dev;
+
+ if (!nvq->done_idx)
+ return;
+
+ vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
+ nvq->done_idx = 0;
+}
+
static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
- struct vhost_virtqueue *vq,
- struct iovec iov[], unsigned int iov_size,
- unsigned int *out_num, unsigned int *in_num)
+ struct vhost_net_virtqueue *nvq,
+ unsigned int *out_num, unsigned int *in_num,
+ bool *busyloop_intr)
{
+ struct vhost_virtqueue *vq = &nvq->vq;
unsigned long uninitialized_var(endtime);
int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
out_num, in_num, NULL, NULL);
if (r == vq->num && vq->busyloop_timeout) {
+ if (!vhost_sock_zcopy(vq->private_data))
+ vhost_net_signal_used(nvq);
preempt_disable();
endtime = busy_clock() + vq->busyloop_timeout;
- while (vhost_can_busy_poll(vq->dev, endtime) &&
- vhost_vq_avail_empty(vq->dev, vq))
+ while (vhost_can_busy_poll(endtime)) {
+ if (vhost_has_work(vq->dev)) {
+ *busyloop_intr = true;
+ break;
+ }
+ if (!vhost_vq_avail_empty(vq->dev, vq))
+ break;
cpu_relax();
+ }
preempt_enable();
r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
out_num, in_num, NULL, NULL);
@@ -463,9 +485,62 @@ static bool vhost_exceeds_maxpend(struct vhost_net *net)
min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2);
}
-/* Expects to be always run from workqueue - which acts as
- * read-size critical section for our kind of RCU. */
-static void handle_tx(struct vhost_net *net)
+static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter,
+ size_t hdr_size, int out)
+{
+ /* Skip header. TODO: support TSO. */
+ size_t len = iov_length(vq->iov, out);
+
+ iov_iter_init(iter, WRITE, vq->iov, out, len);
+ iov_iter_advance(iter, hdr_size);
+
+ return iov_iter_count(iter);
+}
+
+static bool vhost_exceeds_weight(int pkts, int total_len)
+{
+ return total_len >= VHOST_NET_WEIGHT ||
+ pkts >= VHOST_NET_PKT_WEIGHT;
+}
+
+static int get_tx_bufs(struct vhost_net *net,
+ struct vhost_net_virtqueue *nvq,
+ struct msghdr *msg,
+ unsigned int *out, unsigned int *in,
+ size_t *len, bool *busyloop_intr)
+{
+ struct vhost_virtqueue *vq = &nvq->vq;
+ int ret;
+
+ ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr);
+
+ if (ret < 0 || ret == vq->num)
+ return ret;
+
+ if (*in) {
+ vq_err(vq, "Unexpected descriptor format for TX: out %d, int %d\n",
+ *out, *in);
+ return -EFAULT;
+ }
+
+ /* Sanity check */
+ *len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out);
+ if (*len == 0) {
+ vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n",
+ *len, nvq->vhost_hlen);
+ return -EFAULT;
+ }
+
+ return ret;
+}
+
+static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
+{
+ return total_len < VHOST_NET_WEIGHT &&
+ !vhost_vq_avail_empty(vq->dev, vq);
+}
+
+static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
{
struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
struct vhost_virtqueue *vq = &nvq->vq;
@@ -480,67 +555,103 @@ static void handle_tx(struct vhost_net *net)
};
size_t len, total_len = 0;
int err;
- size_t hdr_size;
- struct socket *sock;
- struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
- bool zcopy, zcopy_used;
int sent_pkts = 0;
- mutex_lock(&vq->mutex);
- sock = vq->private_data;
- if (!sock)
- goto out;
+ for (;;) {
+ bool busyloop_intr = false;
- if (!vq_iotlb_prefetch(vq))
- goto out;
+ head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
+ &busyloop_intr);
+ /* On error, stop handling until the next kick. */
+ if (unlikely(head < 0))
+ break;
+ /* Nothing new? Wait for eventfd to tell us they refilled. */
+ if (head == vq->num) {
+ if (unlikely(busyloop_intr)) {
+ vhost_poll_queue(&vq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev,
+ vq))) {
+ vhost_disable_notify(&net->dev, vq);
+ continue;
+ }
+ break;
+ }
- vhost_disable_notify(&net->dev, vq);
- vhost_net_disable_vq(net, vq);
+ vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
+ vq->heads[nvq->done_idx].len = 0;
- hdr_size = nvq->vhost_hlen;
- zcopy = nvq->ubufs;
+ total_len += len;
+ if (tx_can_batch(vq, total_len))
+ msg.msg_flags |= MSG_MORE;
+ else
+ msg.msg_flags &= ~MSG_MORE;
+
+ /* TODO: Check specific error and bomb out unless ENOBUFS? */
+ err = sock->ops->sendmsg(sock, &msg, len);
+ if (unlikely(err < 0)) {
+ vhost_discard_vq_desc(vq, 1);
+ vhost_net_enable_vq(net, vq);
+ break;
+ }
+ if (err != len)
+ pr_debug("Truncated TX packet: len %d != %zd\n",
+ err, len);
+ if (++nvq->done_idx >= VHOST_NET_BATCH)
+ vhost_net_signal_used(nvq);
+ if (vhost_exceeds_weight(++sent_pkts, total_len)) {
+ vhost_poll_queue(&vq->poll);
+ break;
+ }
+ }
+
+ vhost_net_signal_used(nvq);
+}
+
+static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
+{
+ struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+ struct vhost_virtqueue *vq = &nvq->vq;
+ unsigned out, in;
+ int head;
+ struct msghdr msg = {
+ .msg_name = NULL,
+ .msg_namelen = 0,
+ .msg_control = NULL,
+ .msg_controllen = 0,
+ .msg_flags = MSG_DONTWAIT,
+ };
+ size_t len, total_len = 0;
+ int err;
+ struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
+ bool zcopy_used;
+ int sent_pkts = 0;
for (;;) {
- /* Release DMAs done buffers first */
- if (zcopy)
- vhost_zerocopy_signal_used(net, vq);
+ bool busyloop_intr;
+ /* Release DMAs done buffers first */
+ vhost_zerocopy_signal_used(net, vq);
- head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
- ARRAY_SIZE(vq->iov),
- &out, &in);
+ busyloop_intr = false;
+ head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
+ &busyloop_intr);
/* On error, stop handling until the next kick. */
if (unlikely(head < 0))
break;
/* Nothing new? Wait for eventfd to tell us they refilled. */
if (head == vq->num) {
- if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+ if (unlikely(busyloop_intr)) {
+ vhost_poll_queue(&vq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
vhost_disable_notify(&net->dev, vq);
continue;
}
break;
}
- if (in) {
- vq_err(vq, "Unexpected descriptor format for TX: "
- "out %d, int %d\n", out, in);
- break;
- }
- /* Skip header. TODO: support TSO. */
- len = iov_length(vq->iov, out);
- iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
- iov_iter_advance(&msg.msg_iter, hdr_size);
- /* Sanity check */
- if (!msg_data_left(&msg)) {
- vq_err(vq, "Unexpected header len for TX: "
- "%zd expected %zd\n",
- len, hdr_size);
- break;
- }
- len = msg_data_left(&msg);
- zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
- && !vhost_exceeds_maxpend(net)
- && vhost_net_tx_select_zcopy(net);
+ zcopy_used = len >= VHOST_GOODCOPY_LEN
+ && !vhost_exceeds_maxpend(net)
+ && vhost_net_tx_select_zcopy(net);
/* use msg_control to pass vhost zerocopy ubuf info to skb */
if (zcopy_used) {
@@ -562,10 +673,8 @@ static void handle_tx(struct vhost_net *net)
msg.msg_control = NULL;
ubufs = NULL;
}
-
total_len += len;
- if (total_len < VHOST_NET_WEIGHT &&
- !vhost_vq_avail_empty(&net->dev, vq) &&
+ if (tx_can_batch(vq, total_len) &&
likely(!vhost_exceeds_maxpend(net))) {
msg.msg_flags |= MSG_MORE;
} else {
@@ -592,12 +701,37 @@ static void handle_tx(struct vhost_net *net)
else
vhost_zerocopy_signal_used(net, vq);
vhost_net_tx_packet(net);
- if (unlikely(total_len >= VHOST_NET_WEIGHT) ||
- unlikely(++sent_pkts >= VHOST_NET_PKT_WEIGHT)) {
+ if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) {
vhost_poll_queue(&vq->poll);
break;
}
}
+}
+
+/* Expects to be always run from workqueue - which acts as
+ * read-size critical section for our kind of RCU. */
+static void handle_tx(struct vhost_net *net)
+{
+ struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+ struct vhost_virtqueue *vq = &nvq->vq;
+ struct socket *sock;
+
+ mutex_lock(&vq->mutex);
+ sock = vq->private_data;
+ if (!sock)
+ goto out;
+
+ if (!vq_iotlb_prefetch(vq))
+ goto out;
+
+ vhost_disable_notify(&net->dev, vq);
+ vhost_net_disable_vq(net, vq);
+
+ if (vhost_sock_zcopy(sock))
+ handle_tx_zerocopy(net, sock);
+ else
+ handle_tx_copy(net, sock);
+
out:
mutex_unlock(&vq->mutex);
}
@@ -633,53 +767,50 @@ static int sk_has_rx_data(struct sock *sk)
return skb_queue_empty(&sk->sk_receive_queue);
}
-static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq)
+static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
+ bool *busyloop_intr)
{
- struct vhost_virtqueue *vq = &nvq->vq;
- struct vhost_dev *dev = vq->dev;
-
- if (!nvq->done_idx)
- return;
-
- vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
- nvq->done_idx = 0;
-}
-
-static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
-{
- struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
- struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
- struct vhost_virtqueue *vq = &nvq->vq;
+ struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
+ struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
+ struct vhost_virtqueue *rvq = &rnvq->vq;
+ struct vhost_virtqueue *tvq = &tnvq->vq;
unsigned long uninitialized_var(endtime);
- int len = peek_head_len(rvq, sk);
+ int len = peek_head_len(rnvq, sk);
- if (!len && vq->busyloop_timeout) {
+ if (!len && tvq->busyloop_timeout) {
/* Flush batched heads first */
- vhost_rx_signal_used(rvq);
+ vhost_net_signal_used(rnvq);
/* Both tx vq and rx socket were polled here */
- mutex_lock_nested(&vq->mutex, 1);
- vhost_disable_notify(&net->dev, vq);
+ mutex_lock_nested(&tvq->mutex, 1);
+ vhost_disable_notify(&net->dev, tvq);
preempt_disable();
- endtime = busy_clock() + vq->busyloop_timeout;
+ endtime = busy_clock() + tvq->busyloop_timeout;
- while (vhost_can_busy_poll(&net->dev, endtime) &&
- !sk_has_rx_data(sk) &&
- vhost_vq_avail_empty(&net->dev, vq))
+ while (vhost_can_busy_poll(endtime)) {
+ if (vhost_has_work(&net->dev)) {
+ *busyloop_intr = true;
+ break;
+ }
+ if ((sk_has_rx_data(sk) &&
+ !vhost_vq_avail_empty(&net->dev, rvq)) ||
+ !vhost_vq_avail_empty(&net->dev, tvq))
+ break;
cpu_relax();
+ }
preempt_enable();
- if (!vhost_vq_avail_empty(&net->dev, vq))
- vhost_poll_queue(&vq->poll);
- else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
- vhost_disable_notify(&net->dev, vq);
- vhost_poll_queue(&vq->poll);
+ if (!vhost_vq_avail_empty(&net->dev, tvq)) {
+ vhost_poll_queue(&tvq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev, tvq))) {
+ vhost_disable_notify(&net->dev, tvq);
+ vhost_poll_queue(&tvq->poll);
}
- mutex_unlock(&vq->mutex);
+ mutex_unlock(&tvq->mutex);
- len = peek_head_len(rvq, sk);
+ len = peek_head_len(rnvq, sk);
}
return len;
@@ -786,6 +917,7 @@ static void handle_rx(struct vhost_net *net)
s16 headcount;
size_t vhost_hlen, sock_hlen;
size_t vhost_len, sock_len;
+ bool busyloop_intr = false;
struct socket *sock;
struct iov_iter fixup;
__virtio16 num_buffers;
@@ -809,7 +941,8 @@ static void handle_rx(struct vhost_net *net)
vq->log : NULL;
mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
- while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
+ while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+ &busyloop_intr))) {
sock_len += sock_hlen;
vhost_len = sock_len + vhost_hlen;
headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
@@ -820,7 +953,9 @@ static void handle_rx(struct vhost_net *net)
goto out;
/* OK, now we need to know about added descriptors. */
if (!headcount) {
- if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+ if (unlikely(busyloop_intr)) {
+ vhost_poll_queue(&vq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
/* They have slipped one in as we were
* doing that: check again. */
vhost_disable_notify(&net->dev, vq);
@@ -830,6 +965,7 @@ static void handle_rx(struct vhost_net *net)
* they refilled. */
goto out;
}
+ busyloop_intr = false;
if (nvq->rx_ring)
msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
/* On overrun, truncate and discard */
@@ -885,20 +1021,22 @@ static void handle_rx(struct vhost_net *net)
goto out;
}
nvq->done_idx += headcount;
- if (nvq->done_idx > VHOST_RX_BATCH)
- vhost_rx_signal_used(nvq);
+ if (nvq->done_idx > VHOST_NET_BATCH)
+ vhost_net_signal_used(nvq);
if (unlikely(vq_log))
vhost_log_write(vq, vq_log, log, vhost_len);
total_len += vhost_len;
- if (unlikely(total_len >= VHOST_NET_WEIGHT) ||
- unlikely(++recv_pkts >= VHOST_NET_PKT_WEIGHT)) {
+ if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
vhost_poll_queue(&vq->poll);
goto out;
}
}
- vhost_net_enable_vq(net, vq);
+ if (unlikely(busyloop_intr))
+ vhost_poll_queue(&vq->poll);
+ else
+ vhost_net_enable_vq(net, vq);
out:
- vhost_rx_signal_used(nvq);
+ vhost_net_signal_used(nvq);
mutex_unlock(&vq->mutex);
}
@@ -951,7 +1089,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
return -ENOMEM;
}
- queue = kmalloc_array(VHOST_RX_BATCH, sizeof(void *),
+ queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *),
GFP_KERNEL);
if (!queue) {
kfree(vqs);
@@ -1226,7 +1364,8 @@ err_used:
if (ubufs)
vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
- sockfd_put(sock);
+ if (sock)
+ sockfd_put(sock);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -1264,6 +1403,21 @@ done:
return err;
}
+static int vhost_net_set_backend_features(struct vhost_net *n, u64 features)
+{
+ int i;
+
+ mutex_lock(&n->dev.mutex);
+ for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
+ mutex_lock(&n->vqs[i].vq.mutex);
+ n->vqs[i].vq.acked_backend_features = features;
+ mutex_unlock(&n->vqs[i].vq.mutex);
+ }
+ mutex_unlock(&n->dev.mutex);
+
+ return 0;
+}
+
static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
size_t vhost_hlen, sock_hlen, hdr_len;
@@ -1354,6 +1508,17 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
if (features & ~VHOST_NET_FEATURES)
return -EOPNOTSUPP;
return vhost_net_set_features(n, features);
+ case VHOST_GET_BACKEND_FEATURES:
+ features = VHOST_NET_BACKEND_FEATURES;
+ if (copy_to_user(featurep, &features, sizeof(features)))
+ return -EFAULT;
+ return 0;
+ case VHOST_SET_BACKEND_FEATURES:
+ if (copy_from_user(&features, featurep, sizeof(features)))
+ return -EFAULT;
+ if (features & ~VHOST_NET_BACKEND_FEATURES)
+ return -EOPNOTSUPP;
+ return vhost_net_set_backend_features(n, features);
case VHOST_RESET_OWNER:
return vhost_net_reset_owner(n);
case VHOST_SET_OWNER:
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index a502f1af4a21..96c1d8400822 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -315,6 +315,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->log_addr = -1ull;
vq->private_data = NULL;
vq->acked_features = 0;
+ vq->acked_backend_features = 0;
vq->log_base = NULL;
vq->error_ctx = NULL;
vq->kick = NULL;
@@ -1027,28 +1028,40 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct iov_iter *from)
{
- struct vhost_msg_node node;
- unsigned size = sizeof(struct vhost_msg);
- size_t ret;
- int err;
+ struct vhost_iotlb_msg msg;
+ size_t offset;
+ int type, ret;
- if (iov_iter_count(from) < size)
- return 0;
- ret = copy_from_iter(&node.msg, size, from);
- if (ret != size)
+ ret = copy_from_iter(&type, sizeof(type), from);
+ if (ret != sizeof(type))
goto done;
- switch (node.msg.type) {
+ switch (type) {
case VHOST_IOTLB_MSG:
- err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
- if (err)
- ret = err;
+ /* There maybe a hole after type for V1 message type,
+ * so skip it here.
+ */
+ offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
+ break;
+ case VHOST_IOTLB_MSG_V2:
+ offset = sizeof(__u32);
break;
default:
ret = -EINVAL;
- break;
+ goto done;
+ }
+
+ iov_iter_advance(from, offset);
+ ret = copy_from_iter(&msg, sizeof(msg), from);
+ if (ret != sizeof(msg))
+ goto done;
+ if (vhost_process_iotlb_msg(dev, &msg)) {
+ ret = -EFAULT;
+ goto done;
}
+ ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
+ sizeof(struct vhost_msg_v2);
done:
return ret;
}
@@ -1107,13 +1120,28 @@ ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
finish_wait(&dev->wait, &wait);
if (node) {
- ret = copy_to_iter(&node->msg, size, to);
+ struct vhost_iotlb_msg *msg;
+ void *start = &node->msg;
- if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
+ switch (node->msg.type) {
+ case VHOST_IOTLB_MSG:
+ size = sizeof(node->msg);
+ msg = &node->msg.iotlb;
+ break;
+ case VHOST_IOTLB_MSG_V2:
+ size = sizeof(node->msg_v2);
+ msg = &node->msg_v2.iotlb;
+ break;
+ default:
+ BUG();
+ break;
+ }
+
+ ret = copy_to_iter(start, size, to);
+ if (ret != size || msg->type != VHOST_IOTLB_MISS) {
kfree(node);
return ret;
}
-
vhost_enqueue_msg(dev, &dev->pending_list, node);
}
@@ -1126,12 +1154,19 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
struct vhost_dev *dev = vq->dev;
struct vhost_msg_node *node;
struct vhost_iotlb_msg *msg;
+ bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
- node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
+ node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
if (!node)
return -ENOMEM;
- msg = &node->msg.iotlb;
+ if (v2) {
+ node->msg_v2.type = VHOST_IOTLB_MSG_V2;
+ msg = &node->msg_v2.iotlb;
+ } else {
+ msg = &node->msg.iotlb;
+ }
+
msg->type = VHOST_IOTLB_MISS;
msg->iova = iova;
msg->perm = access;
@@ -1560,9 +1595,12 @@ int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
d->iotlb = niotlb;
for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i]->mutex);
- d->vqs[i]->iotlb = niotlb;
- mutex_unlock(&d->vqs[i]->mutex);
+ struct vhost_virtqueue *vq = d->vqs[i];
+
+ mutex_lock(&vq->mutex);
+ vq->iotlb = niotlb;
+ __vhost_vq_meta_reset(vq);
+ mutex_unlock(&vq->mutex);
}
vhost_umem_clean(oiotlb);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 6c844b90a168..466ef7542291 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -132,6 +132,7 @@ struct vhost_virtqueue {
struct vhost_umem *iotlb;
void *private_data;
u64 acked_features;
+ u64 acked_backend_features;
/* Log write descriptors */
void __user *log_base;
struct vhost_log *log;
@@ -147,7 +148,10 @@ struct vhost_virtqueue {
};
struct vhost_msg_node {
- struct vhost_msg msg;
+ union {
+ struct vhost_msg msg;
+ struct vhost_msg_v2 msg_v2;
+ };
struct vhost_virtqueue *vq;
struct list_head node;
};
@@ -238,6 +242,11 @@ static inline bool vhost_has_feature(struct vhost_virtqueue *vq, int bit)
return vq->acked_features & (1ULL << bit);
}
+static inline bool vhost_backend_has_feature(struct vhost_virtqueue *vq, int bit)
+{
+ return vq->acked_backend_features & (1ULL << bit);
+}
+
#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static inline bool vhost_is_little_endian(struct vhost_virtqueue *vq)
{