aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c6
-rw-r--r--drivers/vhost/scsi.c22
-rw-r--r--drivers/vhost/vhost.c112
-rw-r--r--drivers/vhost/vhost.h7
-rw-r--r--drivers/vhost/vsock.c4
5 files changed, 117 insertions, 34 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 36f3d0f49e60..df51a35cf537 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1236,7 +1236,8 @@ static void handle_rx(struct vhost_net *net)
if (nvq->done_idx > VHOST_NET_BATCH)
vhost_net_signal_used(nvq);
if (unlikely(vq_log))
- vhost_log_write(vq, vq_log, log, vhost_len);
+ vhost_log_write(vq, vq_log, log, vhost_len,
+ vq->iov, in);
total_len += vhost_len;
if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
vhost_poll_queue(&vq->poll);
@@ -1336,7 +1337,8 @@ static int vhost_net_open(struct inode *inode, struct file *f)
n->vqs[i].rx_ring = NULL;
vhost_net_buf_init(&n->vqs[i].rxq);
}
- vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
+ vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
+ UIO_MAXIOV + VHOST_NET_BATCH);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 8e10ab436d1f..23593cb23dd0 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -1127,16 +1127,18 @@ vhost_scsi_send_tmf_reject(struct vhost_scsi *vs,
struct vhost_virtqueue *vq,
struct vhost_scsi_ctx *vc)
{
- struct virtio_scsi_ctrl_tmf_resp __user *resp;
struct virtio_scsi_ctrl_tmf_resp rsp;
+ struct iov_iter iov_iter;
int ret;
pr_debug("%s\n", __func__);
memset(&rsp, 0, sizeof(rsp));
rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED;
- resp = vq->iov[vc->out].iov_base;
- ret = __copy_to_user(resp, &rsp, sizeof(rsp));
- if (!ret)
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
else
pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n");
@@ -1147,16 +1149,18 @@ vhost_scsi_send_an_resp(struct vhost_scsi *vs,
struct vhost_virtqueue *vq,
struct vhost_scsi_ctx *vc)
{
- struct virtio_scsi_ctrl_an_resp __user *resp;
struct virtio_scsi_ctrl_an_resp rsp;
+ struct iov_iter iov_iter;
int ret;
pr_debug("%s\n", __func__);
memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */
rsp.response = VIRTIO_SCSI_S_OK;
- resp = vq->iov[vc->out].iov_base;
- ret = __copy_to_user(resp, &rsp, sizeof(rsp));
- if (!ret)
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
else
pr_err("Faulted on virtio_scsi_ctrl_an_resp\n");
@@ -1623,7 +1627,7 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
vqs[i] = &vs->vqs[i].vq;
vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
}
- vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
+ vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV);
vhost_scsi_init_inflight(vs, NULL);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 9f7942cbcbb2..a2e5dc7716e2 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -390,9 +390,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
vq->indirect = kmalloc_array(UIO_MAXIOV,
sizeof(*vq->indirect),
GFP_KERNEL);
- vq->log = kmalloc_array(UIO_MAXIOV, sizeof(*vq->log),
+ vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
GFP_KERNEL);
- vq->heads = kmalloc_array(UIO_MAXIOV, sizeof(*vq->heads),
+ vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
GFP_KERNEL);
if (!vq->indirect || !vq->log || !vq->heads)
goto err_nomem;
@@ -414,7 +414,7 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
}
void vhost_dev_init(struct vhost_dev *dev,
- struct vhost_virtqueue **vqs, int nvqs)
+ struct vhost_virtqueue **vqs, int nvqs, int iov_limit)
{
struct vhost_virtqueue *vq;
int i;
@@ -427,6 +427,7 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->iotlb = NULL;
dev->mm = NULL;
dev->worker = NULL;
+ dev->iov_limit = iov_limit;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
@@ -1034,8 +1035,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
int type, ret;
ret = copy_from_iter(&type, sizeof(type), from);
- if (ret != sizeof(type))
+ if (ret != sizeof(type)) {
+ ret = -EINVAL;
goto done;
+ }
switch (type) {
case VHOST_IOTLB_MSG:
@@ -1054,8 +1057,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
iov_iter_advance(from, offset);
ret = copy_from_iter(&msg, sizeof(msg), from);
- if (ret != sizeof(msg))
+ if (ret != sizeof(msg)) {
+ ret = -EINVAL;
goto done;
+ }
if (vhost_process_iotlb_msg(dev, &msg)) {
ret = -EFAULT;
goto done;
@@ -1733,13 +1738,87 @@ static int log_write(void __user *log_base,
return r;
}
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+ struct vhost_umem *umem = vq->umem;
+ struct vhost_umem_node *u;
+ u64 start, end, l, min;
+ int r;
+ bool hit = false;
+
+ while (len) {
+ min = len;
+ /* More than one GPAs can be mapped into a single HVA. So
+ * iterate all possible umems here to be safe.
+ */
+ list_for_each_entry(u, &umem->umem_list, link) {
+ if (u->userspace_addr > hva - 1 + len ||
+ u->userspace_addr - 1 + u->size < hva)
+ continue;
+ start = max(u->userspace_addr, hva);
+ end = min(u->userspace_addr - 1 + u->size,
+ hva - 1 + len);
+ l = end - start + 1;
+ r = log_write(vq->log_base,
+ u->start + start - u->userspace_addr,
+ l);
+ if (r < 0)
+ return r;
+ hit = true;
+ min = min(l, min);
+ }
+
+ if (!hit)
+ return -EFAULT;
+
+ len -= min;
+ hva += min;
+ }
+
+ return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+ struct iovec iov[64];
+ int i, ret;
+
+ if (!vq->iotlb)
+ return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+ ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+ len, iov, 64, VHOST_ACCESS_WO);
+ if (ret < 0)
+ return ret;
+
+ for (i = 0; i < ret; i++) {
+ ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (ret)
+ return ret;
+ }
+
+ return 0;
+}
+
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len)
+ unsigned int log_num, u64 len, struct iovec *iov, int count)
{
int i, r;
/* Make sure data written is seen before log. */
smp_wmb();
+
+ if (vq->iotlb) {
+ for (i = 0; i < count; i++) {
+ r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (r < 0)
+ return r;
+ }
+ return 0;
+ }
+
for (i = 0; i < log_num; ++i) {
u64 l = min(log[i].len, len);
r = log_write(vq->log_base, log[i].addr, l);
@@ -1769,9 +1848,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
smp_wmb();
/* Log used flag write. */
used = &vq->used->flags;
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof vq->used->flags);
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1789,9 +1867,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
smp_wmb();
/* Log avail event write */
used = vhost_avail_event(vq);
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof *vhost_avail_event(vq));
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -2191,10 +2268,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
/* Make sure data is seen before log. */
smp_wmb();
/* Log used ring entry write. */
- log_write(vq->log_base,
- vq->log_addr +
- ((void __user *)used - (void __user *)vq->used),
- count * sizeof *used);
+ log_used(vq, ((void __user *)used - (void __user *)vq->used),
+ count * sizeof *used);
}
old = vq->last_used_idx;
new = (vq->last_used_idx += count);
@@ -2236,9 +2311,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
/* Make sure used idx is seen before log. */
smp_wmb();
/* Log used index update. */
- log_write(vq->log_base,
- vq->log_addr + offsetof(struct vring_used, idx),
- sizeof vq->used->idx);
+ log_used(vq, offsetof(struct vring_used, idx),
+ sizeof vq->used->idx);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 466ef7542291..9490e7ddb340 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -170,9 +170,11 @@ struct vhost_dev {
struct list_head read_list;
struct list_head pending_list;
wait_queue_head_t wait;
+ int iov_limit;
};
-void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
+void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
+ int nvqs, int iov_limit);
long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *);
@@ -205,7 +207,8 @@ bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *);
bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len);
+ unsigned int log_num, u64 len,
+ struct iovec *iov, int count);
int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index bc42d38ae031..bb5fc0e9fbc2 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -531,7 +531,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
- vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
+ vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs), UIO_MAXIOV);
file->private_data = vsock;
spin_lock_init(&vsock->send_pkt_list_lock);
@@ -642,7 +642,7 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
hash_del_rcu(&vsock->hash);
vsock->guest_cid = guest_cid;
- hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
+ hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
mutex_unlock(&vhost_vsock_mutex);
return 0;