diff options
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 65 |
1 files changed, 51 insertions, 14 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 74db4cd637a7..a5f28708e0e7 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -136,6 +136,8 @@ static const struct vsock_transport *transport_h2g; static const struct vsock_transport *transport_g2h; /* Transport used for DGRAM communication */ static const struct vsock_transport *transport_dgram; +/* Transport used for local communication */ +static const struct vsock_transport *transport_local; static DEFINE_MUTEX(vsock_register_mutex); /**** UTILS ****/ @@ -386,6 +388,21 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) } EXPORT_SYMBOL_GPL(vsock_enqueue_accept); +static bool vsock_use_local_transport(unsigned int remote_cid) +{ + if (!transport_local) + return false; + + if (remote_cid == VMADDR_CID_LOCAL) + return true; + + if (transport_g2h) { + return remote_cid == transport_g2h->get_local_cid(); + } else { + return remote_cid == VMADDR_CID_HOST; + } +} + static void vsock_deassign_transport(struct vsock_sock *vsk) { if (!vsk->transport) @@ -402,9 +419,9 @@ static void vsock_deassign_transport(struct vsock_sock *vsk) * (e.g. during the connect() or when a connection request on a listener * socket is received). * The vsk->remote_addr is used to decide which transport to use: + * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if + * g2h is not loaded, will use local transport; * - remote CID <= VMADDR_CID_HOST will use guest->host transport; - * - remote CID == local_cid (guest->host transport) will use guest->host - * transport for loopback (host->guest transports don't support loopback); * - remote CID > VMADDR_CID_HOST will use host->guest transport; */ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) @@ -419,9 +436,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) new_transport = transport_dgram; break; case SOCK_STREAM: - if (remote_cid <= VMADDR_CID_HOST || - (transport_g2h && - remote_cid == transport_g2h->get_local_cid())) + if (vsock_use_local_transport(remote_cid)) + new_transport = transport_local; + else if (remote_cid <= VMADDR_CID_HOST) new_transport = transport_g2h; else new_transport = transport_h2g; @@ -434,6 +451,12 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) if (vsk->transport == new_transport) return 0; + /* transport->release() must be called with sock lock acquired. + * This path can only be taken during vsock_stream_connect(), + * where we have already held the sock lock. + * In the other cases, this function is called on a new socket + * which is not assigned to any transport. + */ vsk->transport->release(vsk); vsock_deassign_transport(vsk); } @@ -464,6 +487,9 @@ bool vsock_find_cid(unsigned int cid) if (transport_h2g && cid == VMADDR_CID_HOST) return true; + if (transport_local && cid == VMADDR_CID_LOCAL) + return true; + return false; } EXPORT_SYMBOL_GPL(vsock_find_cid); @@ -733,20 +759,18 @@ static void __vsock_release(struct sock *sk, int level) vsk = vsock_sk(sk); pending = NULL; /* Compiler warning. */ - /* The release call is supposed to use lock_sock_nested() - * rather than lock_sock(), if a sock lock should be acquired. - */ - if (vsk->transport) - vsk->transport->release(vsk); - else if (sk->sk_type == SOCK_STREAM) - vsock_remove_sock(vsk); - /* When "level" is SINGLE_DEPTH_NESTING, use the nested * version to avoid the warning "possible recursive locking * detected". When "level" is 0, lock_sock_nested(sk, level) * is the same as lock_sock(sk). */ lock_sock_nested(sk, level); + + if (vsk->transport) + vsk->transport->release(vsk); + else if (sk->sk_type == SOCK_STREAM) + vsock_remove_sock(vsk); + sock_orphan(sk); sk->sk_shutdown = SHUTDOWN_MASK; @@ -2137,7 +2161,7 @@ EXPORT_SYMBOL_GPL(vsock_core_get_transport); int vsock_core_register(const struct vsock_transport *t, int features) { - const struct vsock_transport *t_h2g, *t_g2h, *t_dgram; + const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local; int err = mutex_lock_interruptible(&vsock_register_mutex); if (err) @@ -2146,6 +2170,7 @@ int vsock_core_register(const struct vsock_transport *t, int features) t_h2g = transport_h2g; t_g2h = transport_g2h; t_dgram = transport_dgram; + t_local = transport_local; if (features & VSOCK_TRANSPORT_F_H2G) { if (t_h2g) { @@ -2171,9 +2196,18 @@ int vsock_core_register(const struct vsock_transport *t, int features) t_dgram = t; } + if (features & VSOCK_TRANSPORT_F_LOCAL) { + if (t_local) { + err = -EBUSY; + goto err_busy; + } + t_local = t; + } + transport_h2g = t_h2g; transport_g2h = t_g2h; transport_dgram = t_dgram; + transport_local = t_local; err_busy: mutex_unlock(&vsock_register_mutex); @@ -2194,6 +2228,9 @@ void vsock_core_unregister(const struct vsock_transport *t) if (transport_dgram == t) transport_dgram = NULL; + if (transport_local == t) + transport_local = NULL; + mutex_unlock(&vsock_register_mutex); } EXPORT_SYMBOL_GPL(vsock_core_unregister); |