aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vfio/vfio_iommu_type1.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vfio/vfio_iommu_type1.c')
-rw-r--r--drivers/vfio/vfio_iommu_type1.c84
1 files changed, 77 insertions, 7 deletions
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index d0f731c9920a..c956b85264ae 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -97,6 +97,7 @@ struct vfio_dma {
struct vfio_group {
struct iommu_group *iommu_group;
struct list_head next;
+ bool mdev_group; /* An mdev group */
};
/*
@@ -1311,6 +1312,75 @@ static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
return ret;
}
+static struct device *vfio_mdev_get_iommu_device(struct device *dev)
+{
+ struct device *(*fn)(struct device *dev);
+ struct device *iommu_device;
+
+ fn = symbol_get(mdev_get_iommu_device);
+ if (fn) {
+ iommu_device = fn(dev);
+ symbol_put(mdev_get_iommu_device);
+
+ return iommu_device;
+ }
+
+ return NULL;
+}
+
+static int vfio_mdev_attach_domain(struct device *dev, void *data)
+{
+ struct iommu_domain *domain = data;
+ struct device *iommu_device;
+
+ iommu_device = vfio_mdev_get_iommu_device(dev);
+ if (iommu_device) {
+ if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
+ return iommu_aux_attach_device(domain, iommu_device);
+ else
+ return iommu_attach_device(domain, iommu_device);
+ }
+
+ return -EINVAL;
+}
+
+static int vfio_mdev_detach_domain(struct device *dev, void *data)
+{
+ struct iommu_domain *domain = data;
+ struct device *iommu_device;
+
+ iommu_device = vfio_mdev_get_iommu_device(dev);
+ if (iommu_device) {
+ if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
+ iommu_aux_detach_device(domain, iommu_device);
+ else
+ iommu_detach_device(domain, iommu_device);
+ }
+
+ return 0;
+}
+
+static int vfio_iommu_attach_group(struct vfio_domain *domain,
+ struct vfio_group *group)
+{
+ if (group->mdev_group)
+ return iommu_group_for_each_dev(group->iommu_group,
+ domain->domain,
+ vfio_mdev_attach_domain);
+ else
+ return iommu_attach_group(domain->domain, group->iommu_group);
+}
+
+static void vfio_iommu_detach_group(struct vfio_domain *domain,
+ struct vfio_group *group)
+{
+ if (group->mdev_group)
+ iommu_group_for_each_dev(group->iommu_group, domain->domain,
+ vfio_mdev_detach_domain);
+ else
+ iommu_detach_group(domain->domain, group->iommu_group);
+}
+
static int vfio_iommu_type1_attach_group(void *iommu_data,
struct iommu_group *iommu_group)
{
@@ -1386,7 +1456,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
goto out_domain;
}
- ret = iommu_attach_group(domain->domain, iommu_group);
+ ret = vfio_iommu_attach_group(domain, group);
if (ret)
goto out_domain;
@@ -1418,8 +1488,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
list_for_each_entry(d, &iommu->domain_list, next) {
if (d->domain->ops == domain->domain->ops &&
d->prot == domain->prot) {
- iommu_detach_group(domain->domain, iommu_group);
- if (!iommu_attach_group(d->domain, iommu_group)) {
+ vfio_iommu_detach_group(domain, group);
+ if (!vfio_iommu_attach_group(d, group)) {
list_add(&group->next, &d->group_list);
iommu_domain_free(domain->domain);
kfree(domain);
@@ -1427,7 +1497,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
return 0;
}
- ret = iommu_attach_group(domain->domain, iommu_group);
+ ret = vfio_iommu_attach_group(domain, group);
if (ret)
goto out_domain;
}
@@ -1453,7 +1523,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
return 0;
out_detach:
- iommu_detach_group(domain->domain, iommu_group);
+ vfio_iommu_detach_group(domain, group);
out_domain:
iommu_domain_free(domain->domain);
out_free:
@@ -1544,7 +1614,7 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
if (!group)
continue;
- iommu_detach_group(domain->domain, iommu_group);
+ vfio_iommu_detach_group(domain, group);
list_del(&group->next);
kfree(group);
/*
@@ -1610,7 +1680,7 @@ static void vfio_release_domain(struct vfio_domain *domain, bool external)
list_for_each_entry_safe(group, group_tmp,
&domain->group_list, next) {
if (!external)
- iommu_detach_group(domain->domain, group->iommu_group);
+ vfio_iommu_detach_group(domain, group);
list_del(&group->next);
kfree(group);
}