aboutsummaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/virtio_transport_common.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/virtio_transport_common.c')
-rw-r--r--net/vmw_vsock/virtio_transport_common.c223
1 files changed, 120 insertions, 103 deletions
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index fb2060dffb0a..e5ea29c6bca7 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -29,9 +29,10 @@
/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128
-static const struct virtio_transport *virtio_transport_get_ops(void)
+static const struct virtio_transport *
+virtio_transport_get_ops(struct vsock_sock *vsk)
{
- const struct vsock_transport *t = vsock_core_get_transport();
+ const struct vsock_transport *t = vsock_core_get_transport(vsk);
return container_of(t, struct virtio_transport, transport);
}
@@ -168,7 +169,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
- src_cid = vm_sockets_get_local_cid();
+ src_cid = virtio_transport_get_ops(vsk)->transport.get_local_cid();
src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) {
dst_cid = vsk->remote_addr.svm_cid;
@@ -201,7 +202,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
virtio_transport_inc_tx_pkt(vvs, pkt);
- return virtio_transport_get_ops()->send_pkt(pkt);
+ return virtio_transport_get_ops(vsk)->send_pkt(pkt);
}
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
@@ -268,6 +269,55 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
}
static ssize_t
+virtio_transport_stream_do_peek(struct vsock_sock *vsk,
+ struct msghdr *msg,
+ size_t len)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ struct virtio_vsock_pkt *pkt;
+ size_t bytes, total = 0, off;
+ int err = -EFAULT;
+
+ spin_lock_bh(&vvs->rx_lock);
+
+ list_for_each_entry(pkt, &vvs->rx_queue, list) {
+ off = pkt->off;
+
+ if (total == len)
+ break;
+
+ while (total < len && off < pkt->len) {
+ bytes = len - total;
+ if (bytes > pkt->len - off)
+ bytes = pkt->len - off;
+
+ /* sk_lock is held by caller so no one else can dequeue.
+ * Unlock rx_lock since memcpy_to_msg() may sleep.
+ */
+ spin_unlock_bh(&vvs->rx_lock);
+
+ err = memcpy_to_msg(msg, pkt->buf + off, bytes);
+ if (err)
+ goto out;
+
+ spin_lock_bh(&vvs->rx_lock);
+
+ total += bytes;
+ off += bytes;
+ }
+ }
+
+ spin_unlock_bh(&vvs->rx_lock);
+
+ return total;
+
+out:
+ if (total)
+ err = total;
+ return err;
+}
+
+static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
@@ -339,9 +389,9 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk,
size_t len, int flags)
{
if (flags & MSG_PEEK)
- return -EOPNOTSUPP;
-
- return virtio_transport_stream_do_dequeue(vsk, msg, len);
+ return virtio_transport_stream_do_peek(vsk, msg, len);
+ else
+ return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
@@ -403,20 +453,16 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk,
vsk->trans = vvs;
vvs->vsk = vsk;
- if (psk) {
+ if (psk && psk->trans) {
struct virtio_vsock_sock *ptrans = psk->trans;
- vvs->buf_size = ptrans->buf_size;
- vvs->buf_size_min = ptrans->buf_size_min;
- vvs->buf_size_max = ptrans->buf_size_max;
vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
- } else {
- vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
- vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
- vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
}
- vvs->buf_alloc = vvs->buf_size;
+ if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
+ vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
+
+ vvs->buf_alloc = vsk->buffer_size;
spin_lock_init(&vvs->rx_lock);
spin_lock_init(&vvs->tx_lock);
@@ -426,71 +472,20 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk,
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
-u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
-{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- return vvs->buf_size;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
-
-u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
-{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- return vvs->buf_size_min;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
-
-u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
+/* sk_lock held by the caller */
+void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
- return vvs->buf_size_max;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
+ if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
+ *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
-void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
- if (val < vvs->buf_size_min)
- vvs->buf_size_min = val;
- if (val > vvs->buf_size_max)
- vvs->buf_size_max = val;
- vvs->buf_size = val;
- vvs->buf_alloc = val;
+ vvs->buf_alloc = *val;
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
NULL);
}
-EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
-
-void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
- if (val > vvs->buf_size)
- vvs->buf_size = val;
- vvs->buf_size_min = val;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
-
-void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
- if (val < vvs->buf_size)
- vvs->buf_size = val;
- vvs->buf_size_max = val;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
+EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
@@ -582,9 +577,7 @@ EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
- struct virtio_vsock_sock *vvs = vsk->trans;
-
- return vvs->buf_size;
+ return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
@@ -696,9 +689,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
/* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist.
*/
-static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
+static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
+ struct virtio_vsock_pkt *pkt)
{
- const struct virtio_transport *t;
struct virtio_vsock_pkt *reply;
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
@@ -718,7 +711,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
if (!reply)
return -ENOMEM;
- t = virtio_transport_get_ops();
if (!t) {
virtio_transport_free_pkt(reply);
return -ENOTCONN;
@@ -994,13 +986,39 @@ virtio_transport_send_response(struct vsock_sock *vsk,
return virtio_transport_send_pkt_info(vsk, &info);
}
+static bool virtio_transport_space_update(struct sock *sk,
+ struct virtio_vsock_pkt *pkt)
+{
+ struct vsock_sock *vsk = vsock_sk(sk);
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ bool space_available;
+
+ /* Listener sockets are not associated with any transport, so we are
+ * not able to take the state to see if there is space available in the
+ * remote peer, but since they are only used to receive requests, we
+ * can assume that there is always space available in the other peer.
+ */
+ if (!vvs)
+ return true;
+
+ /* buf_alloc and fwd_cnt is always included in the hdr */
+ spin_lock_bh(&vvs->tx_lock);
+ vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
+ vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
+ space_available = virtio_transport_has_space(vsk);
+ spin_unlock_bh(&vvs->tx_lock);
+ return space_available;
+}
+
/* Handle server socket */
static int
-virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
+virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
+ struct virtio_transport *t)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct vsock_sock *vchild;
struct sock *child;
+ int ret;
if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
virtio_transport_reset(vsk, pkt);
@@ -1012,14 +1030,13 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
return -ENOMEM;
}
- child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
- sk->sk_type, 0);
+ child = vsock_create_connected(sk);
if (!child) {
virtio_transport_reset(vsk, pkt);
return -ENOMEM;
}
- sk->sk_ack_backlog++;
+ sk_acceptq_added(sk);
lock_sock_nested(child, SINGLE_DEPTH_NESTING);
@@ -1031,6 +1048,20 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port));
+ ret = vsock_assign_transport(vchild, vsk);
+ /* Transport assigned (looking at remote_addr) must be the same
+ * where we received the request.
+ */
+ if (ret || vchild->transport != &t->transport) {
+ release_sock(child);
+ virtio_transport_reset(vsk, pkt);
+ sock_put(child);
+ return ret;
+ }
+
+ if (virtio_transport_space_update(child, pkt))
+ child->sk_write_space(child);
+
vsock_insert_connected(vchild);
vsock_enqueue_accept(sk, child);
virtio_transport_send_response(vchild, pkt);
@@ -1041,26 +1072,11 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
return 0;
}
-static bool virtio_transport_space_update(struct sock *sk,
- struct virtio_vsock_pkt *pkt)
-{
- struct vsock_sock *vsk = vsock_sk(sk);
- struct virtio_vsock_sock *vvs = vsk->trans;
- bool space_available;
-
- /* buf_alloc and fwd_cnt is always included in the hdr */
- spin_lock_bh(&vvs->tx_lock);
- vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
- vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
- space_available = virtio_transport_has_space(vsk);
- spin_unlock_bh(&vvs->tx_lock);
- return space_available;
-}
-
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
-void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
+void virtio_transport_recv_pkt(struct virtio_transport *t,
+ struct virtio_vsock_pkt *pkt)
{
struct sockaddr_vm src, dst;
struct vsock_sock *vsk;
@@ -1082,7 +1098,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
- (void)virtio_transport_reset_no_sock(pkt);
+ (void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
@@ -1093,7 +1109,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
if (!sk) {
sk = vsock_find_bound_socket(&dst);
if (!sk) {
- (void)virtio_transport_reset_no_sock(pkt);
+ (void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
}
@@ -1112,7 +1128,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
switch (sk->sk_state) {
case TCP_LISTEN:
- virtio_transport_recv_listen(sk, pkt);
+ virtio_transport_recv_listen(sk, pkt, t);
virtio_transport_free_pkt(pkt);
break;
case TCP_SYN_SENT:
@@ -1130,6 +1146,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
virtio_transport_free_pkt(pkt);
break;
}
+
release_sock(sk);
/* Release refcnt obtained when we fetched this socket out of the