aboutsummaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/virtio_transport_common.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/virtio_transport_common.c')
-rw-r--r--net/vmw_vsock/virtio_transport_common.c247
1 files changed, 208 insertions, 39 deletions
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index f3c4bab2f737..a9980e9b9304 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -74,6 +74,14 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err)
goto out;
+
+ if (msg_data_left(info->msg) == 0 &&
+ info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
+ pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
+
+ if (info->msg->msg_flags & MSG_EOR)
+ pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
+ }
}
trace_virtio_transport_alloc_pkt(src_cid, src_port,
@@ -157,10 +165,22 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
+ if (pkt->tap_delivered)
+ return;
+
vsock_deliver_tap(virtio_transport_build_skb, pkt);
+ pkt->tap_delivered = true;
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
+static u16 virtio_transport_get_type(struct sock *sk)
+{
+ if (sk->sk_type == SOCK_STREAM)
+ return VIRTIO_VSOCK_TYPE_STREAM;
+ else
+ return VIRTIO_VSOCK_TYPE_SEQPACKET;
+}
+
/* This function can only be used on connecting/connected sockets,
* since a socket assigned to a transport is required.
*
@@ -175,6 +195,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
+ info->type = virtio_transport_get_type(sk_vsock(vsk));
+
t_ops = virtio_transport_get_ops(vsk);
if (unlikely(!t_ops))
return -EFAULT;
@@ -265,13 +287,10 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
-static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
- int type,
- struct virtio_vsock_hdr *hdr)
+static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
- .type = type,
.vsk = vsk,
};
@@ -379,11 +398,8 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
* messages, we set the limit to a high value. TODO: experiment
* with different values.
*/
- if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
- virtio_transport_send_credit_update(vsk,
- VIRTIO_VSOCK_TYPE_STREAM,
- NULL);
- }
+ if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
+ virtio_transport_send_credit_update(vsk);
return total;
@@ -393,6 +409,78 @@ out:
return err;
}
+static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
+ struct msghdr *msg,
+ int flags)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ struct virtio_vsock_pkt *pkt;
+ int dequeued_len = 0;
+ size_t user_buf_len = msg_data_left(msg);
+ bool msg_ready = false;
+
+ spin_lock_bh(&vvs->rx_lock);
+
+ if (vvs->msg_count == 0) {
+ spin_unlock_bh(&vvs->rx_lock);
+ return 0;
+ }
+
+ while (!msg_ready) {
+ pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
+
+ if (dequeued_len >= 0) {
+ size_t pkt_len;
+ size_t bytes_to_copy;
+
+ pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
+ bytes_to_copy = min(user_buf_len, pkt_len);
+
+ if (bytes_to_copy) {
+ int err;
+
+ /* sk_lock is held by caller so no one else can dequeue.
+ * Unlock rx_lock since memcpy_to_msg() may sleep.
+ */
+ spin_unlock_bh(&vvs->rx_lock);
+
+ err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
+ if (err) {
+ /* Copy of message failed. Rest of
+ * fragments will be freed without copy.
+ */
+ dequeued_len = err;
+ } else {
+ user_buf_len -= bytes_to_copy;
+ }
+
+ spin_lock_bh(&vvs->rx_lock);
+ }
+
+ if (dequeued_len >= 0)
+ dequeued_len += pkt_len;
+ }
+
+ if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
+ msg_ready = true;
+ vvs->msg_count--;
+
+ if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)
+ msg->msg_flags |= MSG_EOR;
+ }
+
+ virtio_transport_dec_rx_pkt(vvs, pkt);
+ list_del(&pkt->list);
+ virtio_transport_free_pkt(pkt);
+ }
+
+ spin_unlock_bh(&vvs->rx_lock);
+
+ virtio_transport_send_credit_update(vsk);
+
+ return dequeued_len;
+}
+
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
@@ -405,6 +493,38 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk,
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
+ssize_t
+virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
+ struct msghdr *msg,
+ int flags)
+{
+ if (flags & MSG_PEEK)
+ return -EOPNOTSUPP;
+
+ return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
+
+int
+virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
+ struct msghdr *msg,
+ size_t len)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+
+ spin_lock_bh(&vvs->tx_lock);
+
+ if (len > vvs->peer_buf_alloc) {
+ spin_unlock_bh(&vvs->tx_lock);
+ return -EMSGSIZE;
+ }
+
+ spin_unlock_bh(&vvs->tx_lock);
+
+ return virtio_transport_stream_enqueue(vsk, msg, len);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
+
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
@@ -427,6 +547,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
+u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ u32 msg_count;
+
+ spin_lock_bh(&vvs->rx_lock);
+ msg_count = vvs->msg_count;
+ spin_unlock_bh(&vvs->rx_lock);
+
+ return msg_count;
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
+
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
@@ -492,8 +625,7 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
vvs->buf_alloc = *val;
- virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
- NULL);
+ virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
@@ -502,10 +634,7 @@ virtio_transport_notify_poll_in(struct vsock_sock *vsk,
size_t target,
bool *data_ready_now)
{
- if (vsock_stream_has_data(vsk))
- *data_ready_now = true;
- else
- *data_ready_now = false;
+ *data_ready_now = vsock_stream_has_data(vsk) >= target;
return 0;
}
@@ -620,7 +749,6 @@ int virtio_transport_connect(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
- .type = VIRTIO_VSOCK_TYPE_STREAM,
.vsk = vsk,
};
@@ -632,7 +760,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_SHUTDOWN,
- .type = VIRTIO_VSOCK_TYPE_STREAM,
.flags = (mode & RCV_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
(mode & SEND_SHUTDOWN ?
@@ -661,7 +788,6 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW,
- .type = VIRTIO_VSOCK_TYPE_STREAM,
.msg = msg,
.pkt_len = len,
.vsk = vsk,
@@ -684,7 +810,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
- .type = VIRTIO_VSOCK_TYPE_STREAM,
.reply = !!pkt,
.vsk = vsk,
};
@@ -729,6 +854,23 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
return t->send_pkt(reply);
}
+/* This function should be called with sk_lock held and SOCK_DONE set */
+static void virtio_transport_remove_sock(struct vsock_sock *vsk)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ struct virtio_vsock_pkt *pkt, *tmp;
+
+ /* We don't need to take rx_lock, as the socket is closing and we are
+ * removing it.
+ */
+ list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
+ list_del(&pkt->list);
+ virtio_transport_free_pkt(pkt);
+ }
+
+ vsock_remove_sock(vsk);
+}
+
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
if (timeout) {
@@ -761,7 +903,7 @@ static void virtio_transport_do_close(struct vsock_sock *vsk,
(!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
vsk->close_work_scheduled = false;
- vsock_remove_sock(vsk);
+ virtio_transport_remove_sock(vsk);
/* Release refcnt obtained when we scheduled the timeout */
sock_put(sk);
@@ -824,21 +966,16 @@ static bool virtio_transport_close(struct vsock_sock *vsk)
void virtio_transport_release(struct vsock_sock *vsk)
{
- struct virtio_vsock_sock *vvs = vsk->trans;
- struct virtio_vsock_pkt *pkt, *tmp;
struct sock *sk = &vsk->sk;
bool remove_sock = true;
- if (sk->sk_type == SOCK_STREAM)
+ if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
remove_sock = virtio_transport_close(vsk);
- list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
- list_del(&pkt->list);
- virtio_transport_free_pkt(pkt);
+ if (remove_sock) {
+ sock_set_flag(sk, SOCK_DONE);
+ virtio_transport_remove_sock(vsk);
}
-
- if (remove_sock)
- vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
@@ -874,7 +1011,7 @@ destroy:
virtio_transport_reset(vsk, pkt);
sk->sk_state = TCP_CLOSE;
sk->sk_err = skerr;
- sk->sk_error_report(sk);
+ sk_error_report(sk);
return err;
}
@@ -896,6 +1033,9 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
goto out;
}
+ if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM)
+ vvs->msg_count++;
+
/* Try to copy small packets into the buffer of last packet queued,
* to avoid wasting memory queueing the entire buffer with a small
* payload.
@@ -907,13 +1047,18 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
struct virtio_vsock_pkt, list);
/* If there is space in the last packet queued, we copy the
- * new packet in its buffer.
+ * new packet in its buffer. We avoid this if the last packet
+ * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is
+ * delimiter of SEQPACKET message, so 'pkt' is the first packet
+ * of a new message.
*/
- if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
+ if ((pkt->len <= last_pkt->buf_len - last_pkt->len) &&
+ !(le32_to_cpu(last_pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM)) {
memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
pkt->len);
last_pkt->len += pkt->len;
free_pkt = true;
+ last_pkt->hdr.flags |= pkt->hdr.flags;
goto out;
}
}
@@ -936,8 +1081,11 @@ virtio_transport_recv_connected(struct sock *sk,
switch (le16_to_cpu(pkt->hdr.op)) {
case VIRTIO_VSOCK_OP_RW:
virtio_transport_recv_enqueue(vsk, pkt);
- sk->sk_data_ready(sk);
+ vsock_data_ready(sk);
return err;
+ case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
+ virtio_transport_send_credit_update(vsk);
+ break;
case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
sk->sk_write_space(sk);
break;
@@ -984,7 +1132,6 @@ virtio_transport_send_response(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE,
- .type = VIRTIO_VSOCK_TYPE_STREAM,
.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
.remote_port = le32_to_cpu(pkt->hdr.src_port),
.reply = true,
@@ -1080,6 +1227,12 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
return 0;
}
+static bool virtio_transport_valid_type(u16 type)
+{
+ return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
+ (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
+}
+
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
@@ -1105,7 +1258,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
le32_to_cpu(pkt->hdr.buf_alloc),
le32_to_cpu(pkt->hdr.fwd_cnt));
- if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
+ if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
(void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
@@ -1122,14 +1275,29 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
}
}
- vsk = vsock_sk(sk);
+ if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type)) {
+ (void)virtio_transport_reset_no_sock(t, pkt);
+ sock_put(sk);
+ goto free_pkt;
+ }
- space_available = virtio_transport_space_update(sk, pkt);
+ vsk = vsock_sk(sk);
lock_sock(sk);
+ /* Check if sk has been closed before lock_sock */
+ if (sock_flag(sk, SOCK_DONE)) {
+ (void)virtio_transport_reset_no_sock(t, pkt);
+ release_sock(sk);
+ sock_put(sk);
+ goto free_pkt;
+ }
+
+ space_available = virtio_transport_space_update(sk, pkt);
+
/* Update CID in case it has changed after a transport reset event */
- vsk->local_addr.svm_cid = dst.svm_cid;
+ if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
+ vsk->local_addr.svm_cid = dst.svm_cid;
if (space_available)
sk->sk_write_space(sk);
@@ -1151,6 +1319,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
virtio_transport_free_pkt(pkt);
break;
default:
+ (void)virtio_transport_reset_no_sock(t, pkt);
virtio_transport_free_pkt(pkt);
break;
}
@@ -1170,7 +1339,7 @@ EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
- kfree(pkt->buf);
+ kvfree(pkt->buf);
kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);