diff options
Diffstat (limited to '')
-rw-r--r-- | drivers/vhost/vhost.c | 450 |
1 files changed, 370 insertions, 80 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index a92af08e7864..c71d573f1c94 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -187,13 +187,15 @@ EXPORT_SYMBOL_GPL(vhost_work_init); /* Init poll structure */ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, - __poll_t mask, struct vhost_dev *dev) + __poll_t mask, struct vhost_dev *dev, + struct vhost_virtqueue *vq) { init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); init_poll_funcptr(&poll->table, vhost_poll_func); poll->mask = mask; poll->dev = dev; poll->wqh = NULL; + poll->vq = vq; vhost_work_init(&poll->work, fn); } @@ -231,46 +233,102 @@ void vhost_poll_stop(struct vhost_poll *poll) } EXPORT_SYMBOL_GPL(vhost_poll_stop); -void vhost_dev_flush(struct vhost_dev *dev) +static void vhost_worker_queue(struct vhost_worker *worker, + struct vhost_work *work) +{ + if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { + /* We can only add the work to the list after we're + * sure it was not in the list. + * test_and_set_bit() implies a memory barrier. + */ + llist_add(&work->node, &worker->work_list); + vhost_task_wake(worker->vtsk); + } +} + +bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work) +{ + struct vhost_worker *worker; + bool queued = false; + + rcu_read_lock(); + worker = rcu_dereference(vq->worker); + if (worker) { + queued = true; + vhost_worker_queue(worker, work); + } + rcu_read_unlock(); + + return queued; +} +EXPORT_SYMBOL_GPL(vhost_vq_work_queue); + +void vhost_vq_flush(struct vhost_virtqueue *vq) { struct vhost_flush_struct flush; - if (dev->worker) { - init_completion(&flush.wait_event); - vhost_work_init(&flush.work, vhost_flush_work); + init_completion(&flush.wait_event); + vhost_work_init(&flush.work, vhost_flush_work); - vhost_work_queue(dev, &flush.work); + if (vhost_vq_work_queue(vq, &flush.work)) wait_for_completion(&flush.wait_event); - } } -EXPORT_SYMBOL_GPL(vhost_dev_flush); +EXPORT_SYMBOL_GPL(vhost_vq_flush); -void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) +/** + * vhost_worker_flush - flush a worker + * @worker: worker to flush + * + * This does not use RCU to protect the worker, so the device or worker + * mutex must be held. + */ +static void vhost_worker_flush(struct vhost_worker *worker) { - if (!dev->worker) - return; + struct vhost_flush_struct flush; - if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { - /* We can only add the work to the list after we're - * sure it was not in the list. - * test_and_set_bit() implies a memory barrier. - */ - llist_add(&work->node, &dev->worker->work_list); - wake_up_process(dev->worker->vtsk->task); + init_completion(&flush.wait_event); + vhost_work_init(&flush.work, vhost_flush_work); + + vhost_worker_queue(worker, &flush.work); + wait_for_completion(&flush.wait_event); +} + +void vhost_dev_flush(struct vhost_dev *dev) +{ + struct vhost_worker *worker; + unsigned long i; + + xa_for_each(&dev->worker_xa, i, worker) { + mutex_lock(&worker->mutex); + if (!worker->attachment_cnt) { + mutex_unlock(&worker->mutex); + continue; + } + vhost_worker_flush(worker); + mutex_unlock(&worker->mutex); } } -EXPORT_SYMBOL_GPL(vhost_work_queue); +EXPORT_SYMBOL_GPL(vhost_dev_flush); /* A lockless hint for busy polling code to exit the loop */ -bool vhost_has_work(struct vhost_dev *dev) +bool vhost_vq_has_work(struct vhost_virtqueue *vq) { - return dev->worker && !llist_empty(&dev->worker->work_list); + struct vhost_worker *worker; + bool has_work = false; + + rcu_read_lock(); + worker = rcu_dereference(vq->worker); + if (worker && !llist_empty(&worker->work_list)) + has_work = true; + rcu_read_unlock(); + + return has_work; } -EXPORT_SYMBOL_GPL(vhost_has_work); +EXPORT_SYMBOL_GPL(vhost_vq_has_work); void vhost_poll_queue(struct vhost_poll *poll) { - vhost_work_queue(poll->dev, &poll->work); + vhost_vq_work_queue(poll->vq, &poll->work); } EXPORT_SYMBOL_GPL(vhost_poll_queue); @@ -329,35 +387,26 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->busyloop_timeout = 0; vq->umem = NULL; vq->iotlb = NULL; + rcu_assign_pointer(vq->worker, NULL); vhost_vring_call_reset(&vq->call_ctx); __vhost_vq_meta_reset(vq); } -static int vhost_worker(void *data) +static bool vhost_worker(void *data) { struct vhost_worker *worker = data; struct vhost_work *work, *work_next; struct llist_node *node; - for (;;) { - /* mb paired w/ kthread_stop */ - set_current_state(TASK_INTERRUPTIBLE); - - if (vhost_task_should_stop(worker->vtsk)) { - __set_current_state(TASK_RUNNING); - break; - } - - node = llist_del_all(&worker->work_list); - if (!node) - schedule(); + node = llist_del_all(&worker->work_list); + if (node) { + __set_current_state(TASK_RUNNING); node = llist_reverse_order(node); /* make sure flag is seen after deletion */ smp_wmb(); llist_for_each_entry_safe(work, work_next, node, node) { clear_bit(VHOST_WORK_QUEUED, &work->flags); - __set_current_state(TASK_RUNNING); kcov_remote_start_common(worker->kcov_handle); work->fn(work); kcov_remote_stop(); @@ -365,7 +414,7 @@ static int vhost_worker(void *data) } } - return 0; + return !!node; } static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) @@ -468,7 +517,6 @@ void vhost_dev_init(struct vhost_dev *dev, dev->umem = NULL; dev->iotlb = NULL; dev->mm = NULL; - dev->worker = NULL; dev->iov_limit = iov_limit; dev->weight = weight; dev->byte_weight = byte_weight; @@ -478,7 +526,7 @@ void vhost_dev_init(struct vhost_dev *dev, INIT_LIST_HEAD(&dev->read_list); INIT_LIST_HEAD(&dev->pending_list); spin_lock_init(&dev->iotlb_lock); - + xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC); for (i = 0; i < dev->nvqs; ++i) { vq = dev->vqs[i]; @@ -490,7 +538,7 @@ void vhost_dev_init(struct vhost_dev *dev, vhost_vq_reset(dev, vq); if (vq->handle_kick) vhost_poll_init(&vq->poll, vq->handle_kick, - EPOLLIN, dev); + EPOLLIN, dev, vq); } } EXPORT_SYMBOL_GPL(vhost_dev_init); @@ -540,55 +588,284 @@ static void vhost_detach_mm(struct vhost_dev *dev) dev->mm = NULL; } -static void vhost_worker_free(struct vhost_dev *dev) +static void vhost_worker_destroy(struct vhost_dev *dev, + struct vhost_worker *worker) { - struct vhost_worker *worker = dev->worker; - if (!worker) return; - dev->worker = NULL; WARN_ON(!llist_empty(&worker->work_list)); + xa_erase(&dev->worker_xa, worker->id); vhost_task_stop(worker->vtsk); kfree(worker); } -static int vhost_worker_create(struct vhost_dev *dev) +static void vhost_workers_free(struct vhost_dev *dev) +{ + struct vhost_worker *worker; + unsigned long i; + + if (!dev->use_worker) + return; + + for (i = 0; i < dev->nvqs; i++) + rcu_assign_pointer(dev->vqs[i]->worker, NULL); + /* + * Free the default worker we created and cleanup workers userspace + * created but couldn't clean up (it forgot or crashed). + */ + xa_for_each(&dev->worker_xa, i, worker) + vhost_worker_destroy(dev, worker); + xa_destroy(&dev->worker_xa); +} + +static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev) { struct vhost_worker *worker; struct vhost_task *vtsk; char name[TASK_COMM_LEN]; int ret; + u32 id; worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT); if (!worker) - return -ENOMEM; + return NULL; - dev->worker = worker; - worker->kcov_handle = kcov_common_handle(); - init_llist_head(&worker->work_list); snprintf(name, sizeof(name), "vhost-%d", current->pid); vtsk = vhost_task_create(vhost_worker, worker, name); - if (!vtsk) { - ret = -ENOMEM; + if (!vtsk) goto free_worker; - } + mutex_init(&worker->mutex); + init_llist_head(&worker->work_list); + worker->kcov_handle = kcov_common_handle(); worker->vtsk = vtsk; + vhost_task_start(vtsk); - return 0; + ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + if (ret < 0) + goto stop_worker; + worker->id = id; + + return worker; + +stop_worker: + vhost_task_stop(vtsk); free_worker: kfree(worker); - dev->worker = NULL; + return NULL; +} + +/* Caller must have device mutex */ +static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, + struct vhost_worker *worker) +{ + struct vhost_worker *old_worker; + + old_worker = rcu_dereference_check(vq->worker, + lockdep_is_held(&vq->dev->mutex)); + + mutex_lock(&worker->mutex); + worker->attachment_cnt++; + mutex_unlock(&worker->mutex); + rcu_assign_pointer(vq->worker, worker); + + if (!old_worker) + return; + /* + * Take the worker mutex to make sure we see the work queued from + * device wide flushes which doesn't use RCU for execution. + */ + mutex_lock(&old_worker->mutex); + old_worker->attachment_cnt--; + /* + * We don't want to call synchronize_rcu for every vq during setup + * because it will slow down VM startup. If we haven't done + * VHOST_SET_VRING_KICK and not done the driver specific + * SET_ENDPOINT/RUNNUNG then we can skip the sync since there will + * not be any works queued for scsi and net. + */ + mutex_lock(&vq->mutex); + if (!vhost_vq_get_backend(vq) && !vq->kick) { + mutex_unlock(&vq->mutex); + mutex_unlock(&old_worker->mutex); + /* + * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID. + * Warn if it adds support for multiple workers but forgets to + * handle the early queueing case. + */ + WARN_ON(!old_worker->attachment_cnt && + !llist_empty(&old_worker->work_list)); + return; + } + mutex_unlock(&vq->mutex); + + /* Make sure new vq queue/flush/poll calls see the new worker */ + synchronize_rcu(); + /* Make sure whatever was queued gets run */ + vhost_worker_flush(old_worker); + mutex_unlock(&old_worker->mutex); +} + + /* Caller must have device mutex */ +static int vhost_vq_attach_worker(struct vhost_virtqueue *vq, + struct vhost_vring_worker *info) +{ + unsigned long index = info->worker_id; + struct vhost_dev *dev = vq->dev; + struct vhost_worker *worker; + + if (!dev->use_worker) + return -EINVAL; + + worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); + if (!worker || worker->id != info->worker_id) + return -ENODEV; + + __vhost_vq_attach_worker(vq, worker); + return 0; +} + +/* Caller must have device mutex */ +static int vhost_new_worker(struct vhost_dev *dev, + struct vhost_worker_state *info) +{ + struct vhost_worker *worker; + + worker = vhost_worker_create(dev); + if (!worker) + return -ENOMEM; + + info->worker_id = worker->id; + return 0; +} + +/* Caller must have device mutex */ +static int vhost_free_worker(struct vhost_dev *dev, + struct vhost_worker_state *info) +{ + unsigned long index = info->worker_id; + struct vhost_worker *worker; + + worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); + if (!worker || worker->id != info->worker_id) + return -ENODEV; + + mutex_lock(&worker->mutex); + if (worker->attachment_cnt) { + mutex_unlock(&worker->mutex); + return -EBUSY; + } + mutex_unlock(&worker->mutex); + + vhost_worker_destroy(dev, worker); + return 0; +} + +static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp, + struct vhost_virtqueue **vq, u32 *id) +{ + u32 __user *idxp = argp; + u32 idx; + long r; + + r = get_user(idx, idxp); + if (r < 0) + return r; + + if (idx >= dev->nvqs) + return -ENOBUFS; + + idx = array_index_nospec(idx, dev->nvqs); + + *vq = dev->vqs[idx]; + *id = idx; + return 0; +} + +/* Caller must have device mutex */ +long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl, + void __user *argp) +{ + struct vhost_vring_worker ring_worker; + struct vhost_worker_state state; + struct vhost_worker *worker; + struct vhost_virtqueue *vq; + long ret; + u32 idx; + + if (!dev->use_worker) + return -EINVAL; + + if (!vhost_dev_has_owner(dev)) + return -EINVAL; + + ret = vhost_dev_check_owner(dev); + if (ret) + return ret; + + switch (ioctl) { + /* dev worker ioctls */ + case VHOST_NEW_WORKER: + ret = vhost_new_worker(dev, &state); + if (!ret && copy_to_user(argp, &state, sizeof(state))) + ret = -EFAULT; + return ret; + case VHOST_FREE_WORKER: + if (copy_from_user(&state, argp, sizeof(state))) + return -EFAULT; + return vhost_free_worker(dev, &state); + /* vring worker ioctls */ + case VHOST_ATTACH_VRING_WORKER: + case VHOST_GET_VRING_WORKER: + break; + default: + return -ENOIOCTLCMD; + } + + ret = vhost_get_vq_from_user(dev, argp, &vq, &idx); + if (ret) + return ret; + + switch (ioctl) { + case VHOST_ATTACH_VRING_WORKER: + if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) { + ret = -EFAULT; + break; + } + + ret = vhost_vq_attach_worker(vq, &ring_worker); + break; + case VHOST_GET_VRING_WORKER: + worker = rcu_dereference_check(vq->worker, + lockdep_is_held(&dev->mutex)); + if (!worker) { + ret = -EINVAL; + break; + } + + ring_worker.index = idx; + ring_worker.worker_id = worker->id; + + if (copy_to_user(argp, &ring_worker, sizeof(ring_worker))) + ret = -EFAULT; + break; + default: + ret = -ENOIOCTLCMD; + break; + } + return ret; } +EXPORT_SYMBOL_GPL(vhost_worker_ioctl); /* Caller should have device mutex */ long vhost_dev_set_owner(struct vhost_dev *dev) { - int err; + struct vhost_worker *worker; + int err, i; /* Is there an owner already? */ if (vhost_dev_has_owner(dev)) { @@ -598,20 +875,32 @@ long vhost_dev_set_owner(struct vhost_dev *dev) vhost_attach_mm(dev); - if (dev->use_worker) { - err = vhost_worker_create(dev); - if (err) - goto err_worker; - } - err = vhost_dev_alloc_iovecs(dev); if (err) goto err_iovecs; + if (dev->use_worker) { + /* + * This should be done last, because vsock can queue work + * before VHOST_SET_OWNER so it simplifies the failure path + * below since we don't have to worry about vsock queueing + * while we free the worker. + */ + worker = vhost_worker_create(dev); + if (!worker) { + err = -ENOMEM; + goto err_worker; + } + + for (i = 0; i < dev->nvqs; i++) + __vhost_vq_attach_worker(dev->vqs[i], worker); + } + return 0; -err_iovecs: - vhost_worker_free(dev); + err_worker: + vhost_dev_free_iovecs(dev); +err_iovecs: vhost_detach_mm(dev); err_mm: return err; @@ -703,7 +992,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev) dev->iotlb = NULL; vhost_clear_msg(dev); wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); - vhost_worker_free(dev); + vhost_workers_free(dev); vhost_detach_mm(dev); } EXPORT_SYMBOL_GPL(vhost_dev_cleanup); @@ -1591,21 +1880,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg struct file *eventfp, *filep = NULL; bool pollstart = false, pollstop = false; struct eventfd_ctx *ctx = NULL; - u32 __user *idxp = argp; struct vhost_virtqueue *vq; struct vhost_vring_state s; struct vhost_vring_file f; u32 idx; long r; - r = get_user(idx, idxp); + r = vhost_get_vq_from_user(d, argp, &vq, &idx); if (r < 0) return r; - if (idx >= d->nvqs) - return -ENOBUFS; - - idx = array_index_nospec(idx, d->nvqs); - vq = d->vqs[idx]; if (ioctl == VHOST_SET_VRING_NUM || ioctl == VHOST_SET_VRING_ADDR) { @@ -1626,17 +1909,25 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg r = -EFAULT; break; } - if (s.num > 0xffff) { - r = -EINVAL; - break; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) { + vq->last_avail_idx = s.num & 0xffff; + vq->last_used_idx = (s.num >> 16) & 0xffff; + } else { + if (s.num > 0xffff) { + r = -EINVAL; + break; + } + vq->last_avail_idx = s.num; } - vq->last_avail_idx = s.num; /* Forget the cached index value. */ vq->avail_idx = vq->last_avail_idx; break; case VHOST_GET_VRING_BASE: s.index = idx; - s.num = vq->last_avail_idx; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) + s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16); + else + s.num = vq->last_avail_idx; if (copy_to_user(argp, &s, sizeof s)) r = -EFAULT; break; @@ -2575,12 +2866,11 @@ EXPORT_SYMBOL_GPL(vhost_disable_notify); /* Create a new message. */ struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) { - struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); + /* Make sure all padding within the structure is initialized. */ + struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) return NULL; - /* Make sure all padding within the structure is initialized. */ - memset(&node->msg, 0, sizeof node->msg); node->vq = vq; node->msg.type = type; return node; |