aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/net/vmw_vsock/hyperv_transport.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/hyperv_transport.c')
-rw-r--r--net/vmw_vsock/hyperv_transport.c26
1 files changed, 21 insertions, 5 deletions
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 22b608805a91..1c9e65d7d94d 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -165,6 +165,8 @@ static const guid_t srv_id_template =
GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
+static bool hvs_check_transport(struct vsock_sock *vsk);
+
static bool is_valid_srv_id(const guid_t *id)
{
return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
@@ -367,6 +369,18 @@ static void hvs_open_connection(struct vmbus_channel *chan)
new->sk_state = TCP_SYN_SENT;
vnew = vsock_sk(new);
+
+ hvs_addr_init(&vnew->local_addr, if_type);
+ hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
+
+ ret = vsock_assign_transport(vnew, vsock_sk(sk));
+ /* Transport assigned (looking at remote_addr) must be the
+ * same where we received the request.
+ */
+ if (ret || !hvs_check_transport(vnew)) {
+ sock_put(new);
+ goto out;
+ }
hvs_new = vnew->trans;
hvs_new->chan = chan;
} else {
@@ -430,9 +444,6 @@ static void hvs_open_connection(struct vmbus_channel *chan)
new->sk_state = TCP_ESTABLISHED;
sk_acceptq_added(sk);
- hvs_addr_init(&vnew->local_addr, if_type);
- hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
-
hvs_new->vm_srv_id = *if_type;
hvs_new->host_srv_id = *if_instance;
@@ -880,6 +891,11 @@ static struct vsock_transport hvs_transport = {
};
+static bool hvs_check_transport(struct vsock_sock *vsk)
+{
+ return vsk->transport == &hvs_transport;
+}
+
static int hvs_probe(struct hv_device *hdev,
const struct hv_vmbus_device_id *dev_id)
{
@@ -928,7 +944,7 @@ static int __init hvs_init(void)
if (ret != 0)
return ret;
- ret = vsock_core_init(&hvs_transport);
+ ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
if (ret) {
vmbus_driver_unregister(&hvs_drv);
return ret;
@@ -939,7 +955,7 @@ static int __init hvs_init(void)
static void __exit hvs_exit(void)
{
- vsock_core_exit();
+ vsock_core_unregister(&hvs_transport);
vmbus_driver_unregister(&hvs_drv);
}