path: root/net/vmw_vsock/af_vsock.c
author    Linus Torvalds <torvalds@linux-foundation.org>  2019-11-25 20:02:57 -0800
committer Linus Torvalds <torvalds@linux-foundation.org>  2019-11-25 20:02:57 -0800
commit    386403a115f95997c2715691226e11a7b5cffcfd (patch)
tree      a685df70bd3d5b295683713818ddf0752c3d75b6 /net/vmw_vsock/af_vsock.c
parent    Merge git://git.kernel.org/pub/scm/linux/kernel/git/herbert/crypto-2.6 (diff)
parent    Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next (diff)
download  linux-dev-386403a115f95997c2715691226e11a7b5cffcfd.tar.xz
          linux-dev-386403a115f95997c2715691226e11a7b5cffcfd.zip
Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net-next
Pull networking updates from David Miller:
 "Another merge window, another pull full of stuff:

   1) Support alternative names for network devices, from Jiri Pirko.
   2) Introduce per-netns netdev notifiers, also from Jiri Pirko.
   3) Support MSG_PEEK in vsock/virtio, from Matias Ezequiel Vara Larsen.
   4) Allow compiling out the TLS TOE code, from Jakub Kicinski.
   5) Add several new tracepoints to the kTLS code, also from Jakub.
   6) Support set channels ethtool callback in ena driver, from Sameeh Jubran.
   7) New SCTP events SCTP_ADDR_ADDED, SCTP_ADDR_REMOVED, SCTP_ADDR_MADE_PRIM, and SCTP_SEND_FAILED_EVENT. From Xin Long.
   8) Add XDP support to mvneta driver, from Lorenzo Bianconi.
   9) Lots of netfilter hw offload fixes, cleanups and enhancements, from Pablo Neira Ayuso.
  10) PTP support for aquantia chips, from Egor Pomozov.
  11) Add UDP segmentation offload support to igb, ixgbe, and i40e. From Josh Hunt.
  12) Add smart nagle to tipc, from Jon Maloy.
  13) Support L2 field rewrite by TC offloads in bnxt_en, from Venkat Duvvuru.
  14) Add a flow mask cache to OVS, from Tonghao Zhang.
  15) Add XDP support to ice driver, from Maciej Fijalkowski.
  16) Add AF_XDP support to ice driver, from Krzysztof Kazimierczak.
  17) Support UDP GSO offload in atlantic driver, from Igor Russkikh.
  18) Support it in stmmac driver too, from Jose Abreu.
  19) Support TIPC encryption and auth, from Tuong Lien.
  20) Introduce BPF trampolines, from Alexei Starovoitov.
  21) Make page_pool API more numa friendly, from Saeed Mahameed.
  22) Introduce route hints to ipv4 and ipv6, from Paolo Abeni.
  23) Add UDP segmentation offload to cxgb4, Rahul Lakkireddy"

* git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net-next: (1857 commits)
  libbpf: Fix usage of u32 in userspace code
  mm: Implement no-MMU variant of vmalloc_user_node_flags
  slip: Fix use-after-free Read in slip_open
  net: dsa: sja1105: fix sja1105_parse_rgmii_delays()
  macvlan: schedule bc_work even if error
  enetc: add support Credit Based Shaper(CBS) for hardware offload
  net: phy: add helpers phy_(un)lock_mdio_bus
  mdio_bus: don't use managed reset-controller
  ax88179_178a: add ethtool_op_get_ts_info()
  mlxsw: spectrum_router: Fix use of uninitialized adjacency index
  mlxsw: spectrum_router: After underlay moves, demote conflicting tunnels
  bpf: Simplify __bpf_arch_text_poke poke type handling
  bpf: Introduce BPF_TRACE_x helper for the tracing tests
  bpf: Add bpf_jit_blinding_enabled for !CONFIG_BPF_JIT
  bpf, testing: Add various tail call test cases
  bpf, x86: Emit patchable direct jump as tail call
  bpf: Constant map key tracking for prog array pokes
  bpf: Add poke dependency tracking for prog array maps
  bpf: Add initial poke descriptor table for jit images
  bpf: Move owner type, jited info into array auxiliary data
  ...
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r--  net/vmw_vsock/af_vsock.c  397
1 file changed, 295 insertions(+), 102 deletions(-)
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 582a3e4dfce2..74db4cd637a7 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -126,19 +126,18 @@ static struct proto vsock_proto = {
*/
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
-static const struct vsock_transport *transport;
+#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
+
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
static DEFINE_MUTEX(vsock_register_mutex);
-/**** EXPORTS ****/
-
-/* Get the ID of the local context. This is transport dependent. */
-
-int vm_sockets_get_local_cid(void)
-{
- return transport->get_local_cid();
-}
-EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
-
/**** UTILS ****/
/* Each bound VSocket is stored in the bind hash table and each connected
@@ -188,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk)
return __vsock_bind(sk, &local_addr);
}
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
{
int i;
@@ -197,7 +196,6 @@ static int __init vsock_init_tables(void)
for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
INIT_LIST_HEAD(&vsock_connected_table[i]);
- return 0;
}
static void __vsock_insert_bound(struct list_head *list,
@@ -230,9 +228,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
{
struct vsock_sock *vsk;
- list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
- if (addr->svm_port == vsk->local_addr.svm_port)
+ list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
+ if (vsock_addr_equals_addr(addr, &vsk->local_addr))
+ return sk_vsock(vsk);
+
+ if (addr->svm_port == vsk->local_addr.svm_port &&
+ (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
+ addr->svm_cid == VMADDR_CID_ANY))
return sk_vsock(vsk);
+ }
return NULL;
}
@@ -382,6 +386,88 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+ if (!vsk->transport)
+ return;
+
+ vsk->transport->destruct(vsk);
+ module_put(vsk->transport->module);
+ vsk->transport = NULL;
+}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream socket this must be called when vsk->remote_addr is set
+ * (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ * - remote CID <= VMADDR_CID_HOST will use guest->host transport;
+ * - remote CID == local_cid (guest->host transport) will use guest->host
+ * transport for loopback (host->guest transports don't support loopback);
+ * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+ const struct vsock_transport *new_transport;
+ struct sock *sk = sk_vsock(vsk);
+ unsigned int remote_cid = vsk->remote_addr.svm_cid;
+ int ret;
+
+ switch (sk->sk_type) {
+ case SOCK_DGRAM:
+ new_transport = transport_dgram;
+ break;
+ case SOCK_STREAM:
+ if (remote_cid <= VMADDR_CID_HOST ||
+ (transport_g2h &&
+ remote_cid == transport_g2h->get_local_cid()))
+ new_transport = transport_g2h;
+ else
+ new_transport = transport_h2g;
+ break;
+ default:
+ return -ESOCKTNOSUPPORT;
+ }
+
+ if (vsk->transport) {
+ if (vsk->transport == new_transport)
+ return 0;
+
+ vsk->transport->release(vsk);
+ vsock_deassign_transport(vsk);
+ }
+
+ /* We increase the module refcnt to prevent the transport unloading
+ * while there are open sockets assigned to it.
+ */
+ if (!new_transport || !try_module_get(new_transport->module))
+ return -ENODEV;
+
+ ret = new_transport->init(vsk, psk);
+ if (ret) {
+ module_put(new_transport->module);
+ return ret;
+ }
+
+ vsk->transport = new_transport;
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+bool vsock_find_cid(unsigned int cid)
+{
+ if (transport_g2h && cid == transport_g2h->get_local_cid())
+ return true;
+
+ if (transport_h2g && cid == VMADDR_CID_HOST)
+ return true;
+
+ return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
static struct sock *vsock_dequeue_accept(struct sock *listener)
{
struct vsock_sock *vlistener;
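For illustration, the selection rule documented in the hunk above is visible from userspace: for stream sockets the transport is chosen inside connect() from the destination CID (DGRAM sockets get theirs at socket creation). The following is a minimal sketch, not part of the patch, using only the standard <linux/vm_sockets.h> UAPI; the helper name is illustrative.

    #include <unistd.h>
    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Open a stream vsock and connect to the given CID/port. The kernel
     * assigns the transport during connect(): VMADDR_CID_HOST (or the local
     * guest CID, for loopback) routes through the guest->host transport,
     * any other CID routes through the host->guest transport.
     */
    static int example_vsock_connect(unsigned int cid, unsigned int port)
    {
            struct sockaddr_vm addr = {
                    .svm_family = AF_VSOCK,
                    .svm_cid    = cid,
                    .svm_port   = port,
            };
            int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

            if (fd < 0)
                    return -1;
            if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
                    close(fd);
                    return -1;
            }
            return fd;
    }
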
@@ -418,7 +504,12 @@ static bool vsock_is_pending(struct sock *sk)
static int vsock_send_shutdown(struct sock *sk, int mode)
{
- return transport->shutdown(vsock_sk(sk), mode);
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ if (!vsk->transport)
+ return -ENODEV;
+
+ return vsk->transport->shutdown(vsk, mode);
}
static void vsock_pending_work(struct work_struct *work)
@@ -439,7 +530,7 @@ static void vsock_pending_work(struct work_struct *work)
if (vsock_is_pending(sk)) {
vsock_remove_pending(listener, sk);
- listener->sk_ack_backlog--;
+ sk_acceptq_removed(listener);
} else if (!vsk->rejected) {
/* We are not on the pending list and accept() did not reject
* us, so we must have been accepted by our user process. We
@@ -528,13 +619,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
- return transport->dgram_bind(vsk, addr);
+ return vsk->transport->dgram_bind(vsk, addr);
}
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
struct vsock_sock *vsk = vsock_sk(sk);
- u32 cid;
int retval;
/* First ensure this socket isn't already bound. */
@@ -544,10 +634,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
/* Now bind to the provided address or select appropriate values if
* none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that
* like AF_INET prevents binding to a non-local IP address (in most
- * cases), we only allow binding to the local CID.
+ * cases), we only allow binding to a local CID.
*/
- cid = transport->get_local_cid();
- if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+ if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
return -EADDRNOTAVAIL;
switch (sk->sk_socket->type) {
@@ -571,12 +660,12 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
static void vsock_connect_timeout(struct work_struct *work);
-struct sock *__vsock_create(struct net *net,
- struct socket *sock,
- struct sock *parent,
- gfp_t priority,
- unsigned short type,
- int kern)
+static struct sock *__vsock_create(struct net *net,
+ struct socket *sock,
+ struct sock *parent,
+ gfp_t priority,
+ unsigned short type,
+ int kern)
{
struct sock *sk;
struct vsock_sock *psk;
@@ -620,28 +709,24 @@ struct sock *__vsock_create(struct net *net,
vsk->trusted = psk->trusted;
vsk->owner = get_cred(psk->owner);
vsk->connect_timeout = psk->connect_timeout;
+ vsk->buffer_size = psk->buffer_size;
+ vsk->buffer_min_size = psk->buffer_min_size;
+ vsk->buffer_max_size = psk->buffer_max_size;
} else {
vsk->trusted = capable(CAP_NET_ADMIN);
vsk->owner = get_current_cred();
vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
+ vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
+ vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
+ vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
}
- if (transport->init(vsk, psk) < 0) {
- sk_free(sk);
- return NULL;
- }
-
- if (sock)
- vsock_insert_unbound(vsk);
-
return sk;
}
-EXPORT_SYMBOL_GPL(__vsock_create);
static void __vsock_release(struct sock *sk, int level)
{
if (sk) {
- struct sk_buff *skb;
struct sock *pending;
struct vsock_sock *vsk;
@@ -651,7 +736,10 @@ static void __vsock_release(struct sock *sk, int level)
/* The release call is supposed to use lock_sock_nested()
* rather than lock_sock(), if a sock lock should be acquired.
*/
- transport->release(vsk);
+ if (vsk->transport)
+ vsk->transport->release(vsk);
+ else if (sk->sk_type == SOCK_STREAM)
+ vsock_remove_sock(vsk);
/* When "level" is SINGLE_DEPTH_NESTING, use the nested
* version to avoid the warning "possible recursive locking
@@ -662,8 +750,7 @@ static void __vsock_release(struct sock *sk, int level)
sock_orphan(sk);
sk->sk_shutdown = SHUTDOWN_MASK;
- while ((skb = skb_dequeue(&sk->sk_receive_queue)))
- kfree_skb(skb);
+ skb_queue_purge(&sk->sk_receive_queue);
/* Clean up any sockets that never were accepted. */
while ((pending = vsock_dequeue_accept(sk)) != NULL) {
@@ -680,7 +767,7 @@ static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
- transport->destruct(vsk);
+ vsock_deassign_transport(vsk);
/* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel.
@@ -702,15 +789,22 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
return err;
}
+struct sock *vsock_create_connected(struct sock *parent)
+{
+ return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
+ parent->sk_type, 0);
+}
+EXPORT_SYMBOL_GPL(vsock_create_connected);
+
s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
- return transport->stream_has_data(vsk);
+ return vsk->transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);
s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
- return transport->stream_has_space(vsk);
+ return vsk->transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);
@@ -879,6 +973,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
} else if (sock->type == SOCK_STREAM) {
+ const struct vsock_transport *transport = vsk->transport;
lock_sock(sk);
/* Listening sockets that have connections in their accept
@@ -889,7 +984,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLIN | EPOLLRDNORM;
/* If there is something in the queue then we can read. */
- if (transport->stream_is_active(vsk) &&
+ if (transport && transport->stream_is_active(vsk) &&
!(sk->sk_shutdown & RCV_SHUTDOWN)) {
bool data_ready_now = false;
int ret = transport->notify_poll_in(
@@ -954,6 +1049,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr;
+ const struct vsock_transport *transport;
if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP;
@@ -962,6 +1058,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
err = 0;
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
lock_sock(sk);
@@ -1046,8 +1143,8 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;
- if (!transport->dgram_allow(remote_addr->svm_cid,
- remote_addr->svm_port)) {
+ if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+ remote_addr->svm_port)) {
err = -EINVAL;
goto out;
}
@@ -1063,7 +1160,9 @@ out:
static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
size_t len, int flags)
{
- return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags);
+ struct vsock_sock *vsk = vsock_sk(sock->sk);
+
+ return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
}
static const struct proto_ops vsock_dgram_ops = {
@@ -1089,6 +1188,8 @@ static const struct proto_ops vsock_dgram_ops = {
static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
{
+ const struct vsock_transport *transport = vsk->transport;
+
if (!transport->cancel_pkt)
return -EOPNOTSUPP;
@@ -1125,6 +1226,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
struct sockaddr_vm *remote_addr;
long timeout;
DEFINE_WAIT(wait);
@@ -1159,19 +1261,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
goto out;
}
+ /* Set the remote address that we are connecting to. */
+ memcpy(&vsk->remote_addr, remote_addr,
+ sizeof(vsk->remote_addr));
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err)
+ goto out;
+
+ transport = vsk->transport;
+
/* The hypervisor and well-known contexts do not have socket
* endpoints.
*/
- if (!transport->stream_allow(remote_addr->svm_cid,
+ if (!transport ||
+ !transport->stream_allow(remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -ENETUNREACH;
goto out;
}
- /* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
-
err = vsock_auto_bind(vsk);
if (err)
goto out;
@@ -1301,7 +1410,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
err = -listener->sk_err;
if (connected) {
- listener->sk_ack_backlog--;
+ sk_acceptq_removed(listener);
lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
vconnected = vsock_sk(connected);
@@ -1366,6 +1475,23 @@ out:
return err;
}
+static void vsock_update_buffer_size(struct vsock_sock *vsk,
+ const struct vsock_transport *transport,
+ u64 val)
+{
+ if (val > vsk->buffer_max_size)
+ val = vsk->buffer_max_size;
+
+ if (val < vsk->buffer_min_size)
+ val = vsk->buffer_min_size;
+
+ if (val != vsk->buffer_size &&
+ transport && transport->notify_buffer_size)
+ transport->notify_buffer_size(vsk, &val);
+
+ vsk->buffer_size = val;
+}
+
static int vsock_stream_setsockopt(struct socket *sock,
int level,
int optname,
@@ -1375,6 +1501,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
u64 val;
if (level != AF_VSOCK)
@@ -1395,23 +1522,26 @@ static int vsock_stream_setsockopt(struct socket *sock,
err = 0;
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
lock_sock(sk);
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
COPY_IN(val);
- transport->set_buffer_size(vsk, val);
+ vsock_update_buffer_size(vsk, transport, val);
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
COPY_IN(val);
- transport->set_max_buffer_size(vsk, val);
+ vsk->buffer_max_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
COPY_IN(val);
- transport->set_min_buffer_size(vsk, val);
+ vsk->buffer_min_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
@@ -1478,17 +1608,17 @@ static int vsock_stream_getsockopt(struct socket *sock,
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
- val = transport->get_buffer_size(vsk);
+ val = vsk->buffer_size;
COPY_OUT(val);
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
- val = transport->get_max_buffer_size(vsk);
+ val = vsk->buffer_max_size;
COPY_OUT(val);
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
- val = transport->get_min_buffer_size(vsk);
+ val = vsk->buffer_min_size;
COPY_OUT(val);
break;
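With this change the three buffer-size options live in the core vsock_sock (clamped by vsock_update_buffer_size() above) instead of being delegated to the transport, while the socket-option interface itself stays the same. A minimal userspace sketch, assuming 64-bit option values as in the UAPI; the sizes are illustrative only.

    #include <stdint.h>
    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Raise the ceiling first, then set the working size; the kernel clamps
     * the value between buffer_min_size and buffer_max_size and, when it
     * changes, notifies the transport via notify_buffer_size().
     */
    static int example_tune_buffers(int fd)
    {
            uint64_t max  = 1024 * 1024;   /* illustrative */
            uint64_t size = 512 * 1024;    /* illustrative */

            if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
                           &max, sizeof(max)) < 0)
                    return -1;
            return setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
                              &size, sizeof(size));
    }
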
@@ -1519,6 +1649,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
{
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
ssize_t total_written;
long timeout;
int err;
@@ -1527,6 +1658,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
total_written = 0;
err = 0;
@@ -1548,7 +1680,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (sk->sk_state != TCP_ESTABLISHED ||
+ if (!transport || sk->sk_state != TCP_ESTABLISHED ||
!vsock_addr_bound(&vsk->local_addr)) {
err = -ENOTCONN;
goto out;
@@ -1658,6 +1790,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
int err;
size_t target;
ssize_t copied;
@@ -1668,11 +1801,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
err = 0;
lock_sock(sk);
- if (sk->sk_state != TCP_ESTABLISHED) {
+ if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
* peer has not connected or a local shutdown occurred with the
@@ -1846,6 +1980,10 @@ static const struct proto_ops vsock_stream_ops = {
static int vsock_create(struct net *net, struct socket *sock,
int protocol, int kern)
{
+ struct vsock_sock *vsk;
+ struct sock *sk;
+ int ret;
+
if (!sock)
return -EINVAL;
@@ -1865,7 +2003,23 @@ static int vsock_create(struct net *net, struct socket *sock,
sock->state = SS_UNCONNECTED;
- return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM;
+ sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+ if (!sk)
+ return -ENOMEM;
+
+ vsk = vsock_sk(sk);
+
+ if (sock->type == SOCK_DGRAM) {
+ ret = vsock_assign_transport(vsk, NULL);
+ if (ret < 0) {
+ sock_put(sk);
+ return ret;
+ }
+ }
+
+ vsock_insert_unbound(vsk);
+
+ return 0;
}
static const struct net_proto_family vsock_family_ops = {
@@ -1878,11 +2032,20 @@ static long vsock_dev_do_ioctl(struct file *filp,
unsigned int cmd, void __user *ptr)
{
u32 __user *p = ptr;
+ u32 cid = VMADDR_CID_ANY;
int retval = 0;
switch (cmd) {
case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
- if (put_user(transport->get_local_cid(), p) != 0)
+ /* To be compatible with the VMCI behavior, we prioritize the
+ * guest CID instead of the well-known host CID (VMADDR_CID_HOST).
+ */
+ if (transport_g2h)
+ cid = transport_g2h->get_local_cid();
+ else if (transport_h2g)
+ cid = transport_h2g->get_local_cid();
+
+ if (put_user(cid, p) != 0)
retval = -EFAULT;
break;
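For reference, the ioctl whose behaviour changes here is normally issued against the vsock misc device. A minimal userspace sketch follows; the device path and error handling are illustrative.

    #include <fcntl.h>
    #include <unistd.h>
    #include <sys/ioctl.h>
    #include <linux/vm_sockets.h>

    /* With both transports loaded, the guest CID is returned in preference
     * to VMADDR_CID_HOST, matching the historical VMCI behaviour.
     */
    static int example_get_local_cid(unsigned int *cid)
    {
            int ret, fd = open("/dev/vsock", O_RDONLY);

            if (fd < 0)
                    return -1;
            ret = ioctl(fd, IOCTL_VM_SOCKETS_GET_LOCAL_CID, cid);
            close(fd);
            return ret;
    }
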
@@ -1922,24 +2085,13 @@ static struct miscdevice vsock_device = {
.fops = &vsock_device_ops,
};
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
{
- int err = mutex_lock_interruptible(&vsock_register_mutex);
+ int err = 0;
- if (err)
- return err;
-
- if (transport) {
- err = -EBUSY;
- goto err_busy;
- }
-
- /* Transport must be the owner of the protocol so that it can't
- * unload while there are open sockets.
- */
- vsock_proto.owner = owner;
- transport = t;
+ vsock_init_tables();
+ vsock_proto.owner = THIS_MODULE;
vsock_device.minor = MISC_DYNAMIC_MINOR;
err = misc_register(&vsock_device);
if (err) {
@@ -1960,7 +2112,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
goto err_unregister_proto;
}
- mutex_unlock(&vsock_register_mutex);
return 0;
err_unregister_proto:
@@ -1968,44 +2119,86 @@ err_unregister_proto:
err_deregister_misc:
misc_deregister(&vsock_device);
err_reset_transport:
- transport = NULL;
-err_busy:
- mutex_unlock(&vsock_register_mutex);
return err;
}
-EXPORT_SYMBOL_GPL(__vsock_core_init);
-void vsock_core_exit(void)
+static void __exit vsock_exit(void)
{
- mutex_lock(&vsock_register_mutex);
-
misc_deregister(&vsock_device);
sock_unregister(AF_VSOCK);
proto_unregister(&vsock_proto);
-
- /* We do not want the assignment below re-ordered. */
- mb();
- transport = NULL;
-
- mutex_unlock(&vsock_register_mutex);
}
-EXPORT_SYMBOL_GPL(vsock_core_exit);
-const struct vsock_transport *vsock_core_get_transport(void)
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
- /* vsock_register_mutex not taken since only the transport uses this
- * function and only while registered.
- */
- return transport;
+ return vsk->transport;
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);
-static void __exit vsock_exit(void)
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+ const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+ int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+ if (err)
+ return err;
+
+ t_h2g = transport_h2g;
+ t_g2h = transport_g2h;
+ t_dgram = transport_dgram;
+
+ if (features & VSOCK_TRANSPORT_F_H2G) {
+ if (t_h2g) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_h2g = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_G2H) {
+ if (t_g2h) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_g2h = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_DGRAM) {
+ if (t_dgram) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_dgram = t;
+ }
+
+ transport_h2g = t_h2g;
+ transport_g2h = t_g2h;
+ transport_dgram = t_dgram;
+
+err_busy:
+ mutex_unlock(&vsock_register_mutex);
+ return err;
+}
+EXPORT_SYMBOL_GPL(vsock_core_register);
+
+void vsock_core_unregister(const struct vsock_transport *t)
{
- /* Do nothing. This function makes this module removable. */
+ mutex_lock(&vsock_register_mutex);
+
+ if (transport_h2g == t)
+ transport_h2g = NULL;
+
+ if (transport_g2h == t)
+ transport_g2h = NULL;
+
+ if (transport_dgram == t)
+ transport_dgram = NULL;
+
+ mutex_unlock(&vsock_register_mutex);
}
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
-module_init(vsock_init_tables);
+module_init(vsock_init);
module_exit(vsock_exit);
MODULE_AUTHOR("VMware, Inc.");
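
A transport module is then expected to pair vsock_core_register() and vsock_core_unregister() around its own lifetime. The following is a minimal sketch under the assumption of a hypothetical transport called example_transport; the ops table is elided and the feature flag is only an example.

    #include <linux/module.h>
    #include <net/af_vsock.h>

    static struct vsock_transport example_transport = {
            .module = THIS_MODULE,
            /* ... transport callbacks (init, destruct, release, ...) ... */
    };

    static int __init example_vsock_init(void)
    {
            /* Claim only the guest->host role; returns -EBUSY if taken. */
            return vsock_core_register(&example_transport,
                                       VSOCK_TRANSPORT_F_G2H);
    }

    static void __exit example_vsock_exit(void)
    {
            vsock_core_unregister(&example_transport);
    }

    module_init(example_vsock_init);
    module_exit(example_vsock_exit);
    MODULE_LICENSE("GPL");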