diff --git a/drivers/pci/ats.c b/drivers/pci/ats.c index 0b5b0ed7a436..9355f754c7c2 100644 --- a/drivers/pci/ats.c +++ b/drivers/pci/ats.c @@ -44,11 +44,12 @@ int pci_enable_ats(struct pci_dev *dev, int ps) u16 ctrl; struct pci_dev *pdev; - BUG_ON(dev->ats_cap && dev->ats_enabled); - if (!dev->ats_cap) return -EINVAL; + if (WARN_ON(pci_ats_enabled(dev))) + return -EBUSY; + if (ps < PCI_ATS_MIN_STU) return -EINVAL; @@ -83,7 +84,8 @@ void pci_disable_ats(struct pci_dev *dev) struct pci_dev *pdev; u16 ctrl; - BUG_ON(!dev->ats_cap || !dev->ats_enabled); + if (WARN_ON(!pci_ats_enabled(dev))) + return; if (atomic_read(&dev->ats_ref_cnt)) return; /* VFs still enabled */ @@ -107,8 +109,6 @@ void pci_restore_ats_state(struct pci_dev *dev) if (!pci_ats_enabled(dev)) return; - if (!pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ATS)) - BUG(); ctrl = PCI_ATS_CTRL_ENABLE; if (!dev->is_virtfn)