#!/usr/bin/env python3
# SPDX-License-Identifier: Linux-OpenIB
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
# PYTHON_ARGCOMPLETE_OK
from __future__ import annotations
import argparse
import collections
import copy
import importlib
import inspect
import itertools
import json
import os
import re
import subprocess
import sys
import tempfile

from abc import ABC, abstractmethod
from base64 import b64encode, b64decode
from typing import *
from zlib import compress, decompress

# Directory that holds one entry per PCI device. Defaults to the sysfs PCI
# bus directory and can be redirected with the RDMA_TOPO_DEVDIR environment
# variable.
DEVDIR = os.environ.get("RDMA_TOPO_DEVDIR", "/sys/bus/pci/devices/")

# Matches a PCI address written as segment:bus:device.function in lowercase
# hex and captures the four fields (as strings).
BDF_RE = re.compile(r"^([0-9a-f]+?):([0-9a-f]{2}?):([0-9a-f]{2}?)\.([0-9a-f])$")
# ACS mask string, most significant bit first (bit 6 ... bit 0). Bits 4, 3, 2
# and 0 are set: Upstream Forwarding, P2P Completion Redirect, P2P Request
# Redirect and Source Validation.
# NOTE(review): 'x' presumably means "leave this bit unchanged" - confirm
# against the code that applies the mask.
KERNEL_ACS_ISOLATED = "xx111x1"
# PCI vendor IDs, keyed by the vendor name accepted by PCI_VDEVICE().
pci_vendors = {
    "MELLANOX": 0x15B3,
    "NVIDIA": 0x10DE,
}

# PCIe extended capability IDs searched for with parse_ext_cap().
PCI_EXT_CAP_ID_ACS = 0x000D
PCI_EXT_CAP_ID_ATS = 0x000F

# VPD resource tag values used by parse_vpd().
PCI_VPD_LRDT      = 0x80  # Large Resource Data Type flag
PCI_VPD_END_SMALL = 0x78  # Small Resource End Tag
PCI_VPD_END_LARGE = 0x79  # Large Resource End Tag
PCI_VPD_LRDT_ID   = 0x82  # Identifier String
PCI_VPD_LRDT_RO   = 0x90  # VPD-R (Read-Only)


class CommandError(Exception):
    """Error that carries a message meant to be shown to the user."""


# Shared error instance raised by commands (e.g. cmd_topology) when PCITopo
# reports that no supported topology was found.
TOPO_NOT_SUPPORTED = CommandError("No supported topology detected")


def yesno(b: bool) -> str:
    """Return "yes" for a truthy value and "no" otherwise."""
    if b:
        return "yes"
    return "no"


class SysfsDevice(object):
    """Snapshot of the sysfs attributes of one PCI device.

    The values live in the ``data`` dictionary. An instance is either read
    from the device directory below DEVDIR, or rebuilt from a dictionary
    previously produced by to_dict().
    """

    # Keys that must be present for a device to be usable.
    REQUIRED_KEYS = ["realpath", "config", "modalias"]
    # Binary keys stored as base64 encoded zlib data in the dictionary form.
    ENCODED_KEYS = ["config", "vpd"]

    def __init__(self, id: str):
        """Read the attributes of the device directory named `id`.

        A missing optional attribute is stored as None. A missing required
        attribute, or one that cannot be read due to permissions, raises
        CommandError.
        """
        devdir = os.path.join(DEVDIR, id)

        def read_raw(*parts: str) -> bytes:
            with open(os.path.join(devdir, *parts), "rb") as fh:
                return fh.read()

        def read_text(name: str) -> str:
            return read_raw(name).decode("ascii").strip()

        def read_subsystems() -> Dict[str, List[str]]:
            # Map each known subsystem directory to the names found inside it
            found: Dict[str, List[str]] = collections.defaultdict(list)
            for entry in os.listdir(devdir):
                if entry in {"drm", "infiniband", "net", "nvme"}:
                    found[entry].extend(os.listdir(os.path.join(devdir, entry)))
            return dict(found)

        def read_iommu_group() -> int:
            # The group number is the last component of the symlink target
            link = os.readlink(os.path.join(devdir, "iommu_group"))
            return int(os.path.basename(link))

        readers = {
            "realpath": lambda: os.path.realpath(devdir),
            "config": lambda: read_raw("config"),
            "iommu_group": read_iommu_group,
            "modalias": lambda: read_text("modalias"),
            "numa_node": lambda: int(read_text("numa_node")),
            "vpd": lambda: read_raw("vpd"),
            "subsystems": read_subsystems,
        }

        self.data: Dict[str, Any] = {}
        for key, reader in readers.items():
            try:
                value = reader()
            except FileNotFoundError as e:
                if key in SysfsDevice.REQUIRED_KEYS:
                    raise CommandError(f"Missing required sysfs path: {e.filename}")
                value = None
            except PermissionError as e:
                raise CommandError(
                    f"Cannot read sysfs path: {e.filename}. Are you root?"
                )
            self.data[key] = value

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> SysfsDevice:
        """Build a device from the dictionary form produced by to_dict().

        Raises ValueError if a required key is missing or None, or if an
        encoded key cannot be decoded.
        """
        obj = object.__new__(cls)
        obj.data = copy.deepcopy(data)

        for key in SysfsDevice.REQUIRED_KEYS:
            if obj.data.get(key) is None:
                raise ValueError(f"Missing required key '{key}'")

        for key in SysfsDevice.ENCODED_KEYS:
            encoded = obj.data.get(key)
            if encoded is None:
                continue
            try:
                obj.data[key] = decompress(b64decode(encoded))
            except Exception as e:
                raise ValueError(f"Invalid encoded value for key '{key}': {e}")

        return obj

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the data with the binary keys made text safe.

        Binary values are zlib compressed and base64 encoded. Binary keys
        whose value is None are left out.
        """
        res = copy.deepcopy(self.data)
        for key in SysfsDevice.ENCODED_KEYS:
            if key not in res:
                continue
            blob = res[key]
            if blob is None:
                del res[key]
            else:
                # Assign in place so the key keeps its position in the dict
                res[key] = b64encode(compress(blob)).decode("ascii")
        return res

    @property
    def id(self) -> str:
        """Last path component of the resolved device path."""
        return os.path.basename(self.data["realpath"])

    @property
    def realpath(self) -> str:
        return self.data["realpath"]

    @property
    def config(self) -> bytes:
        return self.data["config"]

    @property
    def modalias(self) -> str:
        return self.data["modalias"]

    @property
    def iommu_group(self) -> Optional[int]:
        return self.data.get("iommu_group")

    @property
    def numa_node(self) -> Optional[int]:
        return self.data.get("numa_node")

    @property
    def vpd(self) -> Optional[bytes]:
        return self.data.get("vpd")

    @property
    def subsystems(self) -> Optional[Dict[str, List[str]]]:
        return self.data.get("subsystems")


def parse_vpd(vpd: Optional[bytes]) -> Tuple[Optional[str], Optional[str]]:
    """Extract the V3 keyword and the identifier string from raw VPD data.

    Returns a (v3, name) tuple, in that order. Either element is None when it
    was not found. Parsing stops silently at the first truncated resource or
    at the first value that is not valid ASCII; whatever was found up to that
    point is returned.
    """
    if vpd is None:
        return (None, None)

    def resources(buf: bytes) -> Iterator[Tuple[int, bytes]]:
        # Walk the resource list, yielding (tag byte, payload) pairs
        pos = 0
        end = len(buf)
        while pos < end:
            tag = buf[pos]
            if tag == PCI_VPD_END_SMALL or tag == PCI_VPD_END_LARGE:
                return

            if tag & PCI_VPD_LRDT:
                # Large resource: tag byte plus a 16 bit little endian length
                if end - pos < 3:
                    return
                header = 3
                size = int.from_bytes(buf[pos + 1 : pos + 3], "little")
            else:
                # Small resource: the length is the low three bits of the tag
                header = 1
                size = tag & 0x07

            start = pos + header
            if size > end - start:
                return

            yield tag, buf[start : start + size]
            pos = start + size

    def fields(buf: bytes) -> Iterator[Tuple[str, bytes]]:
        # Each field is a two character keyword, one length byte, then data
        pos = 0
        end = len(buf)
        while end - pos >= 4:
            size = buf[pos + 2]
            if size > end - pos - 3:
                return
            yield buf[pos : pos + 2].decode("ascii"), buf[pos + 3 : pos + 3 + size]
            pos += 3 + size

    v3 = None
    name = None
    try:
        for tag, payload in resources(vpd):
            if tag == PCI_VPD_LRDT_ID:
                name = payload.decode("ascii").strip()
            if tag == PCI_VPD_LRDT_RO:
                for key, raw in fields(payload):
                    if key == "V3":
                        v3 = raw.decode("ascii")
    except UnicodeDecodeError:
        # Best effort: keep whatever was decoded before the bad bytes
        pass

    return v3, name


def parse_ext_cap(config: bytes, cap_id: int) -> Optional[bytes]:
    """Find an extended capability in the PCI configuration space.

    The capability list is walked starting at offset 0x100. The bytes that
    follow the matching 4 byte header are returned, up to the next capability
    or up to the end of `config` when there is no usable next pointer.
    Returns None if the capability is not found.
    """
    size = len(config)
    if size < 0x104:
        return None

    pos = 0x100
    while 0 < pos < size - 4:
        header = int.from_bytes(config[pos : pos + 4], "little")
        nxt = (header >> 20) & 0xFFC

        # A next pointer of zero, one that does not move past this header, or
        # one that is at or beyond the end of the data terminates the walk.
        is_last = nxt <= pos + 4 or nxt >= size
        end = size if is_last else nxt

        if header & 0xFFFF == cap_id:
            return config[pos + 4 : end]
        if is_last:
            return None
        pos = end

    return None


def parse_acs_ctrl(config: bytes) -> Optional[int]:
    """Return the ACS control register value from the PCI configuration space.

    The value is the little endian 16 bit word at bytes 2-3 of the ACS
    capability body. Returns None if the capability is missing or too short.
    """
    cap = parse_ext_cap(config, PCI_EXT_CAP_ID_ACS)
    if cap is not None and len(cap) >= 4:
        return int.from_bytes(cap[2:4], "little")
    return None


def has_ats_cap(config: bytes) -> bool:
    """Report whether an ATS extended capability is present in `config`."""
    cap = parse_ext_cap(config, PCI_EXT_CAP_ID_ATS)
    return cap is not None


def PCI_VDEVICE(vendor: str, device_id: int) -> re.Pattern:
    """Build a modalias pattern for one vendor name and device ID.

    `vendor` must be a key of pci_vendors.
    """
    vid = pci_vendors[vendor]
    pattern = rf"^pci:v{vid:08X}d{device_id:08X}.*$"
    return re.compile(pattern)


def PCI_DEVICE_CLASS(cid: int) -> re.Pattern:
    """Build a modalias pattern for an exact class, subclass and prog-if.

    `cid` packs the three bytes as 0xCCSSPP, the int coding used by the
    kernel.
    """
    base, sub, progif = ((cid >> shift) & 0xFF for shift in (16, 8, 0))
    pattern = rf"^pci:.*bc{base:02X}sc{sub:02X}i{progif:02X}.*$"
    return re.compile(pattern)


def PCI_NVGPU() -> re.Pattern:
    """Build a modalias pattern for NVIDIA devices in base class 0x03."""
    nvidia = pci_vendors["NVIDIA"]
    base_class = 0x03
    pattern = rf"^pci:v{nvidia:08X}.*bc{base_class:02X}.*$"
    return re.compile(pattern)


# Table of modalias matches to the device_type string.
# Order is important. The first matching device type is used.
#
# Keys are compiled modalias patterns. PCIDevice.__init__ walks the table in
# insertion order and stops at the first match, so the vendor/device specific
# entries must stay ahead of the generic class matches at the bottom.
# A "bridge" match on a device that has no parent BDF is reported as
# "generic_rp" by PCIDevice.__init__.
pci_device_types = {
    PCI_VDEVICE("NVIDIA", 0x22B1): "grace_rp",  # NVIDIA Grace PCI Root Port Bridge
    PCI_VDEVICE("NVIDIA", 0x22B2): "grace_rp",  # NVIDIA Grace PCI Root Port Bridge
    PCI_VDEVICE("NVIDIA", 0x22B8): "grace_rp",  # NVIDIA Grace PCI Root Port Bridge
    PCI_VDEVICE("MELLANOX", 0x1021): "cx_nic",  # ConnectX-7
    PCI_VDEVICE("MELLANOX", 0x1023): "cx_nic",  # ConnectX-8
    PCI_VDEVICE("MELLANOX", 0xA2DC): "bf3_nic",  # BlueField-3
    PCI_VDEVICE("MELLANOX", 0x2100): "cx_dma",  # ConnectX-8 DMA Controller
    PCI_VDEVICE("MELLANOX", 0x197B): "bf3_switch",  # USP/DSP of a BF3 switch
    PCI_VDEVICE("MELLANOX", 0x197C): "cx_switch",  # USP/DSP of a CX switch
    PCI_VDEVICE("MELLANOX", 0x1979): "cx_switch",  # USP/DSP of a CX switch
    PCI_DEVICE_CLASS(0x010802): "nvme",
    PCI_NVGPU(): "nvgpu",
    PCI_DEVICE_CLASS(0x060400): "bridge",  # Generic PCI-PCI bridge / Root Port
}


# Modalias class patterns, built with the 0xCCSSPP class coding.
# NOTE(review): the name suggests devices matching these are left out of the
# dump command output; the consumer is not in this part of the file - confirm.
dump_ignored = [
    PCI_DEVICE_CLASS(0x060000),  # Host bridge
    PCI_DEVICE_CLASS(0x060100),  # ISA bridge
    PCI_DEVICE_CLASS(0x080700),  # Root Complex Event Collector
    PCI_DEVICE_CLASS(0x088000),  # System peripheral
    PCI_DEVICE_CLASS(0x110100),  # Performance counters
    PCI_DEVICE_CLASS(0x130000),  # Non-Essential Instrumentation
]


class PCIBDF(
    collections.namedtuple("PCIBDF", ["segment", "bus", "device", "function"])
):
    """PCI address of a device: segment, bus, device and function.

    Being a tuple, instances compare and sort field by field.
    """

    def as_pci(self):
        """Format the address as segment:bus:device.function"""
        seg, bus, dev, fn = self
        return f"{seg}:{bus}:{dev}.{fn}"

    def __str__(self):
        return self.as_pci()

    def __repr__(self):
        seg, bus, dev, fn = self
        return f"PCIBDF({seg}, {bus}, {dev}, {fn})"


def to_pcibdf(s: str) -> Optional[PCIBDF]:
    """Parse a segment:bus:device.function string into a PCIBDF.

    Returns None when `s` does not match BDF_RE. The fields are kept as the
    matched strings.
    """
    m = BDF_RE.match(s)
    return PCIBDF(*m.groups()) if m else None


class PCIDevice(object):
    """One PCI device and its position in the device tree.

    The constructor only does cheap work: it classifies the device from its
    modalias and derives the parent BDF from the sysfs path. finish_loading()
    does the config space and VPD parsing. The `parent` and `children` links
    are filled in by the code that builds the topology.
    """

    device_type = ""
    vpd_v3: Optional[str] = None
    vpd_name: Optional[str] = None
    parent: Optional[PCIDevice] = None
    # finish_loading() only evaluates these for some device types. The
    # defaults keep the attributes readable on every device instead of
    # raising AttributeError.
    has_acs: bool = False
    has_ats: bool = False

    def __init__(self, bdf: PCIBDF, sysfs_device: SysfsDevice):
        self.bdf = bdf
        self.sysfs_device = sysfs_device

        self.iommu_group = self.sysfs_device.iommu_group
        self.numa_node = self.sysfs_device.numa_node
        self.modalias = self.sysfs_device.modalias

        # The directory holding the device in the resolved sysfs path is its
        # parent. It is not a BDF when the device sits directly below a root.
        parent = os.path.basename(os.path.dirname(self.sysfs_device.realpath))
        self.parent_bdf = to_pcibdf(parent)

        # First match wins, the table is ordered from specific to generic
        for k, v in pci_device_types.items():
            if k.match(self.modalias):
                if self.parent_bdf is None and v == "bridge":
                    v = "generic_rp"
                self.device_type = v
                break

        self.children: Set[PCIDevice] = set()
        self.has_acs = False
        self.has_ats = False

    def finish_loading(self):
        """Do more expensive parsing operations"""
        if self.device_type == "cx_nic" or self.device_type == "cx_dma":
            self.vpd_v3, self.vpd_name = parse_vpd(self.sysfs_device.vpd)
        if "switch" in self.device_type or self.device_type.endswith("_rp"):
            self.has_acs = self.get_acs_ctrl() is not None
        if self.device_type == "cx_nic":
            self.has_ats = has_ats_cap(self.sysfs_device.config)

    def iterdownstream(self) -> Generator[PCIDevice, None, None]:
        """Iterate over all downstream devices of this device recursively"""
        for pdev in self.children:
            yield pdev
            yield from pdev.iterdownstream()

    def iterfulltree(self):
        """Iterate over every device below the root of this device's upstream
        path. Yields nothing for a device that has no parent."""
        for pdev in self.iterupstream_path():
            if not pdev.parent:
                yield from pdev.iterdownstream()

    def iterupstream_path(self):
        """Iterate over each step along the upstream path from the devices
        parent to the root."""
        pdev = self.parent
        while pdev:
            yield pdev
            pdev = pdev.parent

    def __repr__(self):
        return f"PCIDevice({self.bdf})"

    def get_acs_ctrl(self):
        """Read the ACS control register from the PCI configuration space.
        Returns None if the device has no ACS capability."""
        return parse_acs_ctrl(self.sysfs_device.config)

    def get_subsystems(self):
        """Return a dictionary of the subsystems the PCI device is connected
        to, mapping the subsystem name to its device names. Empty if none."""
        return self.sysfs_device.subsystems or {}


class NVCX_Complex(ABC):
    """Interface shared by the supported NIC/GPU complex layouts."""

    @property
    @abstractmethod
    def primary_nic(self) -> PCIDevice:
        """The main ConnectX PF of the complex."""

    @abstractmethod
    def compute_acs(self, virt: Optional[bool]) -> Dict[PCIDevice, str]:
        """Return the ACS mask wanted for each device of the complex.

        Backs the commands that check and/or set ACS values.
        """

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Return the complex as a JSON-serializable dictionary.

        Backs the topology dump command when `-j / --json` is given. The
        output format is an interface: keep it backwards compatible.
        """

    @abstractmethod
    def check(self, virt: Optional[bool]) -> bool:
        """Run the extra checks of the complex for the `check` command.

        Returns True only when every check passes.
        """

    @abstractmethod
    def __str__(self) -> str:
        """Return the human readable form printed by the `topo` command."""


class NVCX_DMA_Complex(NVCX_Complex):
    """Hold the related PCI functions together. A complex includes a CX PF, a CX
    DMA function, an GPU and related PCI switches in the DMA function
    segment."""

    def __init__(self, cx_pfs: Set[PCIDevice], cx_dma: PCIDevice, nvgpu: PCIDevice):
        """Build the complex.

        `cx_pfs` are the functions paired with `cx_dma`; it may include
        `cx_dma` itself, which is removed. Raises ValueError when no PF is
        left after that.
        """
        self.cx_pfs = cx_pfs - {cx_dma}
        if not self.cx_pfs:
            # Without this the primary PF lookup below fails on an empty set
            raise ValueError(
                f"CX DMA function {cx_dma} does not have a matching PF, V3 UUID matching failed"
            )
        # The PF with the lowest BDF is the primary one
        self.cx_pf = min(self.cx_pfs, key=lambda x: x.bdf)
        self.cx_dma = cx_dma
        self.nvgpu = nvgpu

        # Identify the switch ports that are part of the shared path that
        # handles the P2P traffic
        self.shared_usp = self.__find_shared_usp()
        for pdev in self.cx_dma.iterupstream_path():
            if pdev in self.shared_usp.children:
                self.cx_dma_dsp = pdev
        for pdev in self.nvgpu.iterupstream_path():
            if pdev in self.shared_usp.children:
                self.nvgpu_dsp = pdev

        # There can be a NVMe device connected to the CX NIC as well. For NVMe
        # it is best to match with GPUs on the same socket, so a NUMA aware
        # approach would be fine, but also the GPU/NIC/NVMe could be
        # consistently paired based on the physical layout.
        self.nvmes: Set[PCIDevice] = set()
        for pdev in self.cx_pf.iterfulltree():
            if pdev.device_type == "nvme":
                self.nvmes.add(pdev)

    @property
    def primary_nic(self) -> PCIDevice:
        return self.cx_pf

    def __find_shared_usp(self) -> PCIDevice:
        """Find the USP that is shared by both devices, the immediate downstream
        bus is the point in the topology where P2P traffic will switch from an
        upstream to downstream direction."""
        common_path = set(self.cx_dma.iterupstream_path()).intersection(
            set(self.nvgpu.iterupstream_path())
        )
        assert common_path

        # The first common ancestor seen walking up from the DMA function is
        # the one closest to both devices
        for pdev in self.cx_dma.iterupstream_path():
            if pdev in common_path:
                assert pdev.device_type == "cx_switch"
                for i in pdev.children:
                    assert i.device_type == "cx_switch"
                return pdev

    def compute_acs(self, _: Optional[bool]) -> Dict[PCIDevice, str]:
        """Return the ACS masks for the two DSPs of the shared switch and for
        the root of the segment. The virt argument is not used."""
        acs: Dict[PCIDevice, str] = {}

        # For the DSP in the shared switch toward the CX8 DMA Direct interface:
        # Enable these bits:
        # bit-4 : ACS Upstream Forwarding
        # bit-3 : ACS P2P Completion Redirect
        # bit-0 : ACS Source Validation
        # Disable these bits:
        # bit-2 : ACS P2P Request Redirect
        assert self.cx_dma_dsp.has_acs
        acs[self.cx_dma_dsp] = "xx110x1"

        # For the DSP in the shared switch toward the GPU:
        # Enable the following bits:
        # bit-4 : ACS Upstream Forwarding
        # bit-2 : ACS P2P Request Redirect
        # bit-0 : ACS Source Validation
        # Disable the following bits:
        # bit-3 : ACS P2P Completion Redirect
        assert self.nvgpu_dsp.has_acs
        acs[self.nvgpu_dsp] = "xx101x1"

        # Disable ACS SV on the root port, this forces the entire segment
        # into one iommu_group and avoids kernel bugs building groups for
        # irregular ACS.
        for pdev in self.cx_dma_dsp.iterupstream_path():
            if not pdev.parent:
                assert pdev.has_acs
                acs[pdev] = "xx111x0"

        return acs

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-serializable form of the complex. Optional keys are
        only present when the value is known."""
        res: Dict[str, Any] = {
            "rdma_nic_pf_bdf": str(self.cx_pf.bdf),
            "rdma_dma_bdf": str(self.cx_dma.bdf),
            "gpu_bdf": str(self.nvgpu.bdf),
            "subsystems": {},
        }
        devname = self.cx_pf.vpd_name
        if devname:
            res["rdma_nic_vpd_name"] = devname
        if self.cx_pf.numa_node is not None:
            res["numa_node"] = self.cx_pf.numa_node
        if self.nvmes:
            res["nvme_bdf"] = str(next(iter(self.nvmes)).bdf)

        for pdev in sorted(
            itertools.chain(self.cx_pfs, [self.nvgpu, self.cx_dma], self.nvmes),
            key=lambda x: x.bdf,
        ):
            subsys = pdev.get_subsystems()
            if subsys:
                res["subsystems"][str(pdev.bdf)] = {
                    name: list(devs) for name, devs in subsys.items()
                }
        return res

    def __str__(self):
        res = f"RDMA NIC={self.cx_pf.bdf}, GPU={self.nvgpu.bdf}, RDMA DMA Function={self.cx_dma.bdf}\n"
        devname = self.cx_pf.vpd_name
        if devname:
            res += f"\t{devname}\n"

        if self.cx_pf.numa_node is not None:
            res += f"\tNUMA Node: {self.cx_pf.numa_node}\n"

        if self.cx_pfs:
            res += print_list("NIC PCI device", [str(I.bdf) for I in self.cx_pfs])

        # Merge the subsystem device names of every function of the complex
        subsystems: Dict[str, Set[str]] = collections.defaultdict(set)
        for pdev in itertools.chain(self.cx_pfs, [self.nvgpu, self.cx_dma], self.nvmes):
            for k, v in pdev.get_subsystems().items():
                subsystems[k].update(v)
        res += print_list("RDMA device", subsystems["infiniband"])
        res += print_list("Net device", subsystems["net"])
        res += print_list("DRM device", subsystems["drm"])
        res += print_list("NVMe device", subsystems["nvme"])

        # Drop the trailing newline
        return res[:-1]

    def check(self, _: Optional[bool]) -> bool:
        """Check that the DMA function and the GPU share one iommu_group.
        The virt argument is not used."""
        # Correct iommu_groups are required to avoid NVreg_GrdmaPciTopoCheckOverride
        if (
            self.cx_dma.iommu_group == self.nvgpu.iommu_group
            and self.cx_dma.iommu_group is not None
        ):
            check_ok(
                f"Kernel iommu_group for DMA {self.cx_dma.bdf} and GPU {self.nvgpu.bdf} are both {self.cx_dma.iommu_group}"
            )
            return True

        check_fail(
            f"Kernel iommu_group for DMA {self.cx_dma.bdf} and GPU {self.nvgpu.bdf} are not equal {self.cx_dma.iommu_group} != {self.nvgpu.iommu_group}"
        )
        return False


class NVCX_Inline_Complex(NVCX_Complex):
    """A CX NIC PF and a GPU that sit below the same shared switch, without a
    DMA function. Also holds the root port and the two downstream ports of
    the shared switch that lead to the NIC and to the GPU."""

    def __init__(
        self,
        root_port: PCIDevice,
        shared_usp: PCIDevice,
        cx_pf: PCIDevice,
        nvgpu: PCIDevice,
    ):
        """Locate the DSPs of `shared_usp` that lead to a NIC and to a GPU.

        Raises ValueError if either DSP is missing, or if more than one DSP
        leads to the same kind of device.
        """
        self.root_port = root_port
        self.cx_pf = cx_pf
        self.nvgpu = nvgpu
        self.cx_pf_dsp = None
        self.nvgpu_dsp = None

        for dsp in shared_usp.children:
            # The first NIC or GPU found below a DSP decides its role
            for pdev in dsp.iterdownstream():
                if pdev.device_type == "cx_nic":
                    if self.cx_pf_dsp is not None:
                        raise ValueError(
                            "Multiple CX NIC DSPs under the same shared switch not supported"
                        )
                    self.cx_pf_dsp = dsp
                    break
                if pdev.device_type == "nvgpu":
                    if self.nvgpu_dsp is not None:
                        raise ValueError(
                            "Multiple GPU DSPs under the same shared switch not supported"
                        )
                    self.nvgpu_dsp = dsp
                    break

        if not self.cx_pf_dsp:
            raise ValueError("CX NIC DSP not found in the topology")
        if not self.nvgpu_dsp:
            raise ValueError("GPU DSP not found in the topology")

    @property
    def primary_nic(self) -> PCIDevice:
        return self.cx_pf

    def compute_acs(self, virt: Optional[bool]) -> Dict[PCIDevice, str]:
        """Return the ACS masks for both DSPs and the root port.

        Raises CommandError if one of them lacks ACS or if `virt` is None.
        """
        if not self.cx_pf_dsp.has_acs:
            raise CommandError(f"CX NIC DSP {self.cx_pf_dsp.bdf} lacks ACS")
        if not self.nvgpu_dsp.has_acs:
            raise CommandError(f"GPU DSP {self.nvgpu_dsp.bdf} lacks ACS")
        if not self.root_port.has_acs:
            raise CommandError(f"Root port {self.root_port.bdf} lacks ACS")
        if virt is None:
            raise CommandError("Unexpected: Could not determine virt mode")

        if virt:
            return {
                # The DSPs of the NIC which is non DD in the shared switch should
                # have the following enabled:
                # bit-6 : ACS Direct Translated P2P
                # bit-4 : ACS Upstream Forwarding
                # bit-3 : ACS P2P Completion Redirect
                # bit-2 : ACS P2P Request Redirect
                # bit-0 : ACS Source Validation
                self.cx_pf_dsp: "1x111x1",
                # The DSPs of the GPU in the shared switch and the RP of the NIC/GPU
                # should have the following enabled, matching the kernel default:
                # bit-4 : ACS Upstream Forwarding
                # bit-3 : ACS P2P Completion Redirect
                # bit-2 : ACS P2P Request Redirect
                # bit-0 : ACS Source Validation
                self.nvgpu_dsp: KERNEL_ACS_ISOLATED,
                self.root_port: KERNEL_ACS_ISOLATED,
            }
        else:
            return {
                # The DSPs of both the NIC and GPU in the shared switch and
                # RPs of the NIC/GPU should have the following disabled:
                # bit-4 : ACS Upstream Forwarding
                # bit-3 : ACS P2P Completion Redirect
                # bit-2 : ACS P2P Request Redirect
                # bit-0 : ACS Source Validation
                self.cx_pf_dsp: "xx000x0",
                self.nvgpu_dsp: "xx000x0",
                self.root_port: "xx000x0",
            }

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-serializable form of the complex. Optional keys are
        only present when the value is known (or True for rdma_nic_ats)."""
        res: Dict[str, Any] = {
            "rdma_nic_pf_bdf": str(self.cx_pf.bdf),
            "gpu_bdf": str(self.nvgpu.bdf),
            "subsystems": {},
        }
        devname = self.cx_pf.vpd_name
        if devname:
            res["rdma_nic_vpd_name"] = devname
        if self.cx_pf.numa_node is not None:
            res["numa_node"] = self.cx_pf.numa_node
        if self.cx_pf.has_ats:
            res["rdma_nic_ats"] = self.cx_pf.has_ats

        for pdev in sorted([self.cx_pf, self.nvgpu], key=lambda x: x.bdf):
            subsys = pdev.get_subsystems()
            if subsys:
                res["subsystems"][str(pdev.bdf)] = {
                    name: list(devs) for name, devs in subsys.items()
                }
        return res

    def __check_ats(self, virt: bool) -> bool:
        """The NIC must expose ATS exactly when virtualization is used"""
        status = "available" if self.cx_pf.has_ats else "not available"
        msg = f"ATS capability for {self.cx_pf.device_type} {self.cx_pf.bdf} is {status}"

        if self.cx_pf.has_ats != virt:
            check_fail(msg)
            return False

        check_ok(msg)
        return True

    def __check_iommu_group(self, virt: bool) -> bool:
        """With virtualization the NIC and GPU must be in different, known
        iommu_groups. Without it they must share a group or both have none."""
        cxpf = f"{self.cx_pf.device_type} {self.cx_pf.bdf}"
        nvgpu = f"{self.nvgpu.device_type} {self.nvgpu.bdf}"
        prefix = f"Kernel iommu_group for {cxpf} and {nvgpu}"

        equal = f"equal {self.cx_pf.iommu_group} == {self.nvgpu.iommu_group}"
        not_equal = f"not equal {self.cx_pf.iommu_group} != {self.nvgpu.iommu_group}"

        if virt:
            if self.cx_pf.iommu_group is None:
                check_fail(f"Kernel iommu_group is missing for {cxpf}")
                return False

            if self.nvgpu.iommu_group is None:
                check_fail(f"Kernel iommu_group is missing for {nvgpu}")
                return False

            if self.cx_pf.iommu_group == self.nvgpu.iommu_group:
                check_fail(f"{prefix} are {equal}")
                return False

            check_ok(f"{prefix} are {not_equal}")
            return True
        else:
            if self.cx_pf.iommu_group is None and self.nvgpu.iommu_group is None:
                check_ok(f"{prefix} are not set")
                return True

            if self.cx_pf.iommu_group != self.nvgpu.iommu_group:
                check_fail(f"{prefix} are {not_equal}")
                return False

            check_ok(f"{prefix} are {equal}")
            return True

    def check(self, virt: Optional[bool]) -> bool:
        """Run the ATS and iommu_group checks. Both always run so that both
        results get reported.

        Raises CommandError if `virt` is None.
        """
        # Same failure mode as compute_acs(). An assert would vanish under -O
        # and the checks would then compare against None.
        if virt is None:
            raise CommandError("Unexpected: Could not determine virt mode")
        res_ats = self.__check_ats(virt)
        res_iommu_group = self.__check_iommu_group(virt)
        return res_ats and res_iommu_group

    def __str__(self):
        res = f"RDMA NIC={self.cx_pf.bdf}, GPU={self.nvgpu.bdf}\n"
        devname = self.cx_pf.vpd_name
        if devname:
            res += f"\t{devname}\n"
        if self.cx_pf.numa_node is not None:
            res += f"\tNUMA Node: {self.cx_pf.numa_node}\n"

        res += f"\tNIC ATS: {yesno(self.cx_pf.has_ats)}\n"

        # Merge the subsystem device names of the NIC and the GPU
        subsystems: Dict[str, Set[str]] = collections.defaultdict(set)
        for pdev in [self.cx_pf, self.nvgpu]:
            for k, v in pdev.get_subsystems().items():
                subsystems[k].update(v)
        res += print_list("RDMA device", subsystems["infiniband"])
        res += print_list("Net device", subsystems["net"])
        res += print_list("DRM device", subsystems["drm"])
        res += print_list("NVMe device", subsystems["nvme"])

        # Drop the trailing newline
        return res[:-1]


def check_parent(pdev: PCIDevice, parent_type: str):
    """Return the parent of `pdev` if its device_type is `parent_type`.

    Returns None otherwise, including when `pdev` or its parent is missing,
    which allows calls to be chained along an expected upstream path.
    """
    parent = pdev.parent if pdev else None
    if parent and parent.device_type == parent_type:
        return parent
    return None


class PCITopo(object):
    """Load the PCI topology from sysfs and organize it"""

    def __init__(self, sysfs_dump: Optional[str] = None, virt: Optional[bool] = None):
        """Load the devices and build the NVCX_Complex objects.

        `sysfs_dump` is the name of a JSON dump file to read instead of the
        live DEVDIR directory. `virt` selects the virtualization mode; None
        means auto-detect. It must be None on topologies with CX DMA
        functions, otherwise CommandError is raised.
        """
        if sysfs_dump:
            sysfs_devices = self.__parse_dump(sysfs_dump)
        else:
            sysfs_devices = [SysfsDevice(fn) for fn in os.listdir(DEVDIR)]
        self.devices = self.__load_devices(sysfs_devices)
        self.nvcxs: List[NVCX_Complex] = []
        self.has_cx_dma = any(
            pdev.device_type == "cx_dma" for pdev in self.devices.values()
        )
        self.has_gpu_and_nic = False

        if self.has_cx_dma and virt is not None:
            raise CommandError(
                "--virt / --no-virt is not supported on DMA-based topologies"
            )
        self.virt = virt

        if not self.has_cx_dma:
            # The inline topology needs at least one of each of these
            found = {
                "cx_switch": False,
                "nvgpu": False,
                "cx_nic": False,
            }
            for pdev in self.devices.values():
                if pdev.device_type in found:
                    found[pdev.device_type] = True
            self.has_gpu_and_nic = all(found.values())

            if not self.has_gpu_and_nic:
                # Nothing supported here, skip the expensive parsing
                return

        for pdev in self.devices.values():
            pdev.finish_loading()
        self.__build_topo()

    def __parse_dump(self, filename: str) -> List[SysfsDevice]:
        """Read a JSON dump file holding a list of SysfsDevice dictionaries.

        Raises CommandError if the file cannot be read or is not valid.
        """
        res: List[SysfsDevice] = []
        try:
            with open(filename, "rt") as F:
                data = json.load(F)

            if not isinstance(data, list):
                raise ValueError(f"Expected list, got '{type(data).__name__}'")

            num_items = len(data)
            for i, item in enumerate(data):
                if not isinstance(item, dict):
                    raise ValueError(
                        f"Item {i}/{num_items}: Expected dictionary, got '{type(item).__name__}'"
                    )
                try:
                    res.append(SysfsDevice.from_dict(item))
                except Exception as e:
                    raise ValueError(f"Item {i}/{num_items}: {e}") from e
            return res
        except ValueError as e:
            # json.JSONDecodeError is a ValueError as well
            raise CommandError(f"Invalid sysfs dump file: {e}") from e
        except OSError as e:
            # Missing file, no permission, a directory, ...
            raise CommandError(f"Failed to read sysfs dump file: {e}") from e

    def __load_devices(self, sysfs_devices: List[SysfsDevice]):
        """Create a PCIDevice for every entry whose id is a PCI BDF"""
        res: Dict[PCIBDF, PCIDevice] = {}
        for sdev in sysfs_devices:
            bdf = to_pcibdf(sdev.id)
            if not bdf:
                continue
            assert bdf not in res
            res[bdf] = PCIDevice(bdf, sdev)
        return res

    def __get_nvcx_complex(self, cx_dma: PCIDevice):
        """Match the topology for the switch complex using a CX DMA function and a
        single GPU. It has two nested switches:

        RP --> SW -> CX_DMA
                  -> SW -> GPU

        Raises ValueError if the topology around `cx_dma` does not match.
        """
        assert cx_dma.device_type == "cx_dma"
        if not cx_dma.vpd_v3:
            raise ValueError(f"CX DMA function {cx_dma} does not have a VPD V3 UUID")

        # The DMA and PF are matched using the UUID from the VPD. The DMA
        # function is itself recorded under its own UUID, so it has to be
        # removed before testing whether any PF matched.
        cx_pfs = self.vpd_v3s.get(cx_dma.vpd_v3, set()) - {cx_dma}
        if not cx_pfs:
            raise ValueError(
                f"CX DMA function {cx_dma} does not have a matching PF, V3 UUID matching failed"
            )

        # Path from the DMA to the root port
        cx_dma_dsp = check_parent(cx_dma, "cx_switch")
        cx_usp = check_parent(cx_dma_dsp, "cx_switch")
        grace_rp = check_parent(cx_usp, "grace_rp")
        if not grace_rp:
            raise ValueError(
                f"CX DMA function {cx_dma} has an unrecognized upstream path"
            )

        # Path from the GPU to the root port
        nvgpus = [
            pdev for pdev in grace_rp.iterdownstream() if pdev.device_type == "nvgpu"
        ]
        if len(nvgpus) != 1:
            raise ValueError(f"CX DMA function {cx_dma} does not have a nearby GPU")
        nvgpu = nvgpus[0]
        nvgpu_dsp2 = check_parent(nvgpu, "cx_switch")
        nvgpu_usp2 = check_parent(nvgpu_dsp2, "cx_switch")
        nvgpu_dsp1 = check_parent(nvgpu_usp2, "cx_switch")
        if cx_usp != check_parent(nvgpu_dsp1, "cx_switch"):
            raise ValueError(
                f"CX DMA function {cx_dma} has an unrecognized upstream path from the GPU"
            )

        # Sanity check there is nothing unexpected in the topology
        alldevs = {
            cx_dma,
            cx_dma_dsp,
            cx_usp,
            nvgpu,
            nvgpu_dsp2,
            nvgpu_usp2,
            nvgpu_dsp1,
        }
        topodevs = set(grace_rp.iterdownstream())
        if alldevs != topodevs:
            raise ValueError(
                f"CX DMA function {cx_dma} has unexpected PCI devices in the topology"
            )

        return NVCX_DMA_Complex(cx_pfs, cx_dma, nvgpu)

    def __get_nvcx_inline_complex(self, nvgpu: PCIDevice):
        """Match the topology for the inline complex using a GPU and a CX NIC.

        RP --> SW -> CX_NIC
                  -> SW -> GPU

        Raises ValueError if the topology around `nvgpu` does not match.
        """
        assert nvgpu.device_type == "nvgpu"

        nvgpu_dsp2 = check_parent(nvgpu, "cx_switch")
        nvgpu_usp2 = check_parent(nvgpu_dsp2, "cx_switch")
        nvgpu_dsp1 = check_parent(nvgpu_usp2, "cx_switch")
        shared_usp1 = check_parent(nvgpu_dsp1, "cx_switch")
        if not shared_usp1:
            raise ValueError(f"GPU {nvgpu} has an unrecognized upstream path")

        for pdev in shared_usp1.iterupstream_path():
            if pdev.device_type == "generic_rp":
                root_port = pdev
                break
        else:
            raise ValueError(
                f"Could not find root port for shared USP {shared_usp1.bdf}"
            )

        for pdev in shared_usp1.iterdownstream():
            if pdev.device_type == "cx_nic":
                cx_nic = pdev
                break
        else:
            raise ValueError(f"GPU {nvgpu} does not have a nearby CX NIC")

        return NVCX_Inline_Complex(root_port, shared_usp1, cx_nic, nvgpu)

    def __auto_detect_virt(self) -> bool:
        """Auto-detect if virtualization will be used on this system.

        The ATS capability of the primary NICs decides. Raises CommandError
        if the NICs disagree.
        """
        first = self.nvcxs[0].primary_nic.has_ats
        if not all(nvcx.primary_nic.has_ats == first for nvcx in self.nvcxs):
            raise CommandError(
                "Could not auto-detect virtualization: CX NICs have different ATS settings"
            )

        return first

    def __build_topo(self):
        """Collect cross-device information together and build the NVCX_Complex
        objects for the cx_dma functions, or for the GPUs on inline topologies"""
        self.vpd_v3s: Dict[str, Set[PCIDevice]] = collections.defaultdict(set)
        for pdev in self.devices.values():
            if pdev.parent_bdf and pdev.parent_bdf in self.devices:
                pdev.parent = self.devices[pdev.parent_bdf]
                pdev.parent.children.add(pdev)

            # Many PCI functions may share the same V3
            if pdev.vpd_v3:
                self.vpd_v3s[pdev.vpd_v3].add(pdev)

        if self.has_cx_dma:
            for pdev in self.devices.values():
                if pdev.device_type == "cx_dma":
                    nvcx = self.__get_nvcx_complex(pdev)
                    self.nvcxs.append(nvcx)
        elif self.has_gpu_and_nic:
            for pdev in self.devices.values():
                if pdev.device_type == "nvgpu":
                    nvcx = self.__get_nvcx_inline_complex(pdev)
                    self.nvcxs.append(nvcx)

        if self.has_gpu_and_nic and len(self.nvcxs) > 0:
            if self.virt is None:
                self.virt = self.__auto_detect_virt()

        self.nvcxs.sort(key=lambda x: x.primary_nic.bdf)

    @property
    def supported(self) -> bool:
        """True if the system has a topology that is supported by the rdma_topo tool"""
        return (self.has_cx_dma or self.has_gpu_and_nic) and len(self.nvcxs) > 0

    def compute_acs(self):
        """Return a dictionary of PCI devices and the ACS mask the device should
        have"""
        acs: Dict[PCIDevice, str] = {}
        for nvcx in self.nvcxs:
            acs.update(nvcx.compute_acs(self.virt))

        # Enable, using kernel default, or disable ACS on all other CX
        # bridges and Grace RP based on the virt parameter or if the topology
        # has CX DMA functions.
        #
        # To enable (matches kernel default):
        # bit-4 : ACS Upstream Forwarding
        # bit-3 : ACS P2P Completion Redirect
        # bit-2 : ACS P2P Request Redirect
        # bit-0 : ACS Source Validation
        for pdev in self.devices.values():
            if (
                pdev not in acs
                and ("switch" in pdev.device_type or "grace_rp" in pdev.device_type)
                and pdev.has_acs
            ):
                acs[pdev] = (
                    KERNEL_ACS_ISOLATED if self.has_cx_dma or self.virt else "xx000x0"
                )
        return acs


def add_sysfs_dump_argument(parser):
    """Register the -F/--sysfs-dump-file option on *parser*.

    The given value is stored as ``sysfs_dump``; it stays None when the
    option is not used.
    """
    option_names = ("-F", "--sysfs-dump-file")
    parser.add_argument(
        *option_names,
        dest="sysfs_dump",
        default=None,
        action="store",
        help="Use a file produced by the rdma_topo dump command as input",
    )


def add_virt_argument(parser: argparse.ArgumentParser) -> None:
    """Register the --virt / --no-virt option on *parser*.

    The result is stored as ``virt``: True or False when one of the flags is
    given, None when neither is.
    """
    options = {
        "dest": "virt",
        "default": None,
        "action": argparse.BooleanOptionalAction,
        "help": "Whether virtualization will be used on this system. Auto-detect if not set.",
    }
    parser.add_argument("--virt", **options)


# -------------------------------------------------------------------
def print_list(title: str, items: list[str]):
    """Format *items* as one tab-indented, newline-terminated line.

    Nothing is printed: the text is returned. The items are sorted and
    joined with ", ", and an "s" is appended to *title* when there is more
    than one item. An empty *items* gives an empty string.
    """
    if items:
        label = (title + "s") if len(items) > 1 else title
        return "\t" + label + ": " + ", ".join(sorted(items)) + "\n"
    return ""


def args_topology(parser):
    """Declare the command line options of the topology sub-command."""
    json_flag = {
        "action": "store_true",
        "dest": "json",
        "help": "Output in machine readable JSON format",
    }
    parser.add_argument("-j", "--json", **json_flag)
    add_sysfs_dump_argument(parser)


def topo_json(topo: PCITopo):
    """Print every entry of ``topo.nvcxs`` as one indented JSON array on stdout."""
    complexes = [nvcx.to_dict() for nvcx in topo.nvcxs]
    print(json.dumps(complexes, indent=4))


def cmd_topology(args):
    """List the ConnectX NICs in the system with the corresponding NIC
    function, associated GPU, and, optionally, DMA Direct function."""
    # This command has no --virt option, so None is passed for it.
    topo = PCITopo(args.sysfs_dump, None)
    if not topo.supported:
        raise TOPO_NOT_SUPPORTED

    if not args.json:
        # Human readable output: one printed entry per complex
        for nvcx in topo.nvcxs:
            print(nvcx)
        return None
    return topo_json(topo)


# Alternative sub-command names, read back through get_cmd_aliases()
cmd_topology.__aliases__ = ("topo",)

# -------------------------------------------------------------------
def update_file(fn: str, new_content: str) -> bool:
    """Make fn have new_content. If fn already has new_content nothing is
    done.

    The new content is written to a temporary file in the same directory as
    fn and then renamed over fn, so fn is never missing or partially written
    while it is being replaced.

    Returns True when fn was created or rewritten and False when it already
    held new_content.
    """
    try:
        with open(fn, "rt") as F:
            if F.read() == new_content:
                return False
    except FileNotFoundError:
        pass

    # delete=False because the temporary name no longer exists after a
    # successful rename; failures are cleaned up explicitly below.
    tmp = tempfile.NamedTemporaryFile(
        dir=os.path.dirname(fn), mode="wt", delete=False
    )
    try:
        with tmp:
            tmp.write(new_content)
        # Temporary files are created readable by the owner only
        os.chmod(tmp.name, 0o644)
        # Atomic on POSIX, unlike unlinking fn and then linking the new file
        os.replace(tmp.name, fn)
    except BaseException:
        try:
            os.unlink(tmp.name)
        except FileNotFoundError:
            pass
        raise
    return True


def args_write_grub_acs(parser):
    """Declare the command line options of the write-grub-acs sub-command."""
    option_specs = (
        (
            ("-n", "--dry-run"),
            {
                "action": "store_true",
                "dest": "dry_run",
                "help": "Output the grub configuration to stdout and make no changes",
            },
        ),
        (
            ("--output",),
            {
                "action": "store",
                "default": "/etc/default/grub.d/config-acs.cfg",
                "help": "Grub dropin file to use for the kernel command line",
            },
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    add_virt_argument(parser)


def cmd_write_grub_acs(args):
    """Generate a grub dropin file to have the kernel commandline set the
    required ACS flags during system boot. This is the recommended way to
    configure ACS on systems but requires a compatible kernel.

    If the system does not have any need of ACS flags the dropin file will be
    removed. This command is intended for Debian style systems with a
    /etc/default/grub.d and update-grub command."""
    topo = PCITopo(None, args.virt)
    if not topo.supported:
        # A dry run only reports; it never touches an existing dropin file
        if args.dry_run:
            raise TOPO_NOT_SUPPORTED
        if os.path.exists(args.output):
            print(
                f"W: Found ACS drop-in file {args.output} but the system does not have a supported topology. Deleting file."
            )
            os.unlink(args.output)
        return

    # One "<mask>@<bdf>" entry per device, ordered by BDF. Devices whose mask
    # equals KERNEL_ACS_ISOLATED are left out of the kernel command line.
    config_acs = [
        f"{mask}@{pdev.bdf}"
        for pdev, mask in sorted(topo.compute_acs().items(), key=lambda x: x[0].bdf)
        if mask != KERNEL_ACS_ISOLATED
    ]
    acs_arg = ";".join(config_acs)
    grub_lines = [
        f"# Generated by {sys.argv[0]} do not change. ACS settings for RDMA GPU Direct",
        f'GRUB_CMDLINE_LINUX="$GRUB_CMDLINE_LINUX pci=config_acs=\\"{acs_arg}\\""',
    ]
    grub_conf = "\n".join(grub_lines)

    if args.dry_run:
        print(grub_conf)
        return

    # dirname is "" when --output is a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory that is named.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # update-grub is only needed when the dropin file actually changed
    if update_file(args.output, grub_conf + "\n"):
        subprocess.check_call(["update-grub"])


# -------------------------------------------------------------------
def combine_acs(cur_acs, new_acs):
    """Apply the ACS mask string *new_acs* to the register value *cur_acs*.

    The last character of *new_acs* describes bit 0, the one before it bit 1,
    and so on. A "1" sets the bit, a "0" clears it, and any other character
    leaves the bit as it is in *cur_acs*. Clearing a bit also masks the value
    with 0xFFFF. The resulting value is returned.
    """
    result = cur_acs
    for bit, flag in enumerate(reversed(new_acs)):
        mask = 1 << bit
        if flag == "1":
            result |= mask
        elif flag == "0":
            result &= 0xFFFF ^ mask
    return result


def args_setpci_acs(parser):
    """Declare the command line options of the setpci-acs sub-command."""
    dry_run_flag = {
        "action": "store_true",
        "dest": "dry_run",
        "help": "Output the setpci commands to stdout and make no changes",
    }
    parser.add_argument("-n", "--dry-run", **dry_run_flag)
    add_virt_argument(parser)


def cmd_setpci_acs(args):
    """Execute a series of set_pci commands that will immediately change the ACS
    settings to the required values. This is compatible with older kernels, but
    is not recommended. The kernel must boot with ACS enabled and the GPU driver
    must have the NVreg_GrdmaPciTopoCheckOverride=1 reg key set to disable
    safety checks that old kernels cannot support.

    NOTE: In this configuration unprivileged userspace can trigger platform RAS
    failures, use with caution!
    """
    topo = PCITopo(None, args.virt)
    if not topo.supported:
        raise TOPO_NOT_SUPPORTED

    # Work out every command first so nothing is changed if a register
    # cannot be read.
    wanted = topo.compute_acs()
    cmds: List[List[str]] = []
    for pdev, mask in sorted(wanted.items(), key=lambda item: item[0].bdf):
        current = pdev.get_acs_ctrl()
        if current is None:
            raise CommandError(
                f"Could not read ACS control register for {pdev.device_type} {pdev.bdf}"
            )
        target = combine_acs(current, mask)
        if target != current:
            # Only devices whose register value would change get a command
            cmds.append(
                [
                    "setpci",
                    "-r",
                    "-s",
                    str(pdev.bdf),
                    f"ECAP_ACS+0x6.w={target:04x}",
                ]
            )

    if args.dry_run:
        for cmd in cmds:
            print(" ".join(cmd))
    else:
        for cmd in cmds:
            subprocess.check_call(cmd)


# -------------------------------------------------------------------
def args_check(parser):
    """Declare the command line options of the check sub-command."""
    for add_option in (add_sysfs_dump_argument, add_virt_argument):
        add_option(parser)


def check_ok(msg: str):
    """Print *msg* on stdout as a passed check: "OK", a tab, then the message."""
    print("OK", msg, sep="\t")


def check_fail(msg: str):
    """Print *msg* on stdout as a failed check: "FAIL", a tab, then the message."""
    print("FAIL", msg, sep="\t")


def cmd_check(args):
    """Check that the running kernel and PCI environment are setup correctly for
    GPU Direct with ConnectX DMA Direct PCI functions."""
    topo = PCITopo(args.sysfs_dump, args.virt)
    if not topo.supported:
        raise TOPO_NOT_SUPPORTED
    if topo.has_cx_dma:
        check_ok("All ConnectX DMA functions have correct PCI topology")
    elif topo.has_gpu_and_nic:
        check_ok("All NIC/GPU complexes have correct PCI topology")

    failed = False
    wanted = topo.compute_acs()
    for pdev, mask in sorted(wanted.items(), key=lambda item: item[0].bdf):
        current = pdev.get_acs_ctrl()
        if current is None:
            check_fail(
                f"Could not read ACS control register for {pdev.device_type} {pdev.bdf}"
            )
            failed = True
            continue

        # The register is correct when applying the mask would not change it
        target = combine_acs(current, mask)
        if target != current:
            check_fail(
                f"ACS for {pdev.device_type} {pdev.bdf} has incorrect values {current:07b} != {mask}, (0x{current:x} != 0x{target:x})"
            )
            failed = True
            continue
        check_ok(
            f"ACS for {pdev.device_type} {pdev.bdf} has correct values {current:07b} = {mask}"
        )

    # Every complex is checked, even after an earlier one has failed
    results = [nvcx.check(topo.virt) for nvcx in topo.nvcxs]
    if not all(results):
        failed = True

    if failed:
        sys.exit(100)

# -------------------------------------------------------------------
def args_dump(parser):
    """The dump sub-command has no options, so nothing is added to *parser*."""
    return None


def cmd_dump(args) -> None:
    """Dump the PCI topology to a file that can be used as input"""
    # Devices are visited in sorted name order and built one at a time
    devices = (SysfsDevice(name) for name in sorted(os.listdir(DEVDIR)))
    # Devices whose modalias matches an entry of dump_ignored are left out
    sd_json: List[Dict[str, Any]] = [
        sd.to_dict()
        for sd in devices
        if not any(pattern.match(sd.modalias) for pattern in dump_ignored)
    ]
    # The JSON document is written to stdout
    json.dump(sd_json, sys.stdout, indent=4)


# -------------------------------------------------------------------
def load_all_commands(name):
    """Yield the sub-commands defined in the module called *name*.

    A sub-command is a function named ``cmd_<x>`` that has a matching
    ``args_<x>`` attribute in the same module. Each one is yielded as a
    ``(name, cmd function, args function)`` tuple, in ``dir()`` order.
    """
    module = importlib.import_module(name)
    for attr in dir(module):
        if not attr.startswith("cmd_"):
            continue
        handler = getattr(module, attr)
        if not inspect.isfunction(handler):
            continue
        register_args = getattr(module, "args_" + attr[4:], None)
        if register_args is not None:
            yield (attr, handler, register_args)


def get_cmd_aliases(fn):
    """Return the ``__aliases__`` attribute of *fn*, or an empty tuple if it has none."""
    return getattr(fn, "__aliases__", ())

def main():
    """Parse the command line and run the selected sub-command.

    Every ``cmd_*`` function of this module that has a matching ``args_*``
    function becomes a sub-command. A CommandError raised by the sub-command
    is printed with an "E: " prefix and the process exits with status 100.
    """
    description = """NVIDIA ConnectX GPU Direct ACS tool for Direct NIC platforms

This tool is used to view and control the PCI Access Control Flags (ACS) related
to the Direct NIC topology on supported NVIDIA platforms with ConnectX and
Blackwell family GPUs.

Direct NIC platforms have a unique multipath PCI topology where the ConnectX
has a main PCI function and a related DMA Direct function linked to the GPU.

This platform requires specific ACS flags in the PCI topology for reliable
operation, this tool helps users generate ACS settings for the local system.
"""
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(title="Sub Commands", dest="command")
    subparsers.required = True

    # One sub parser per loaded command, in sorted order. The sub-command
    # name is the function name without "cmd_" and with "-" instead of "_".
    for name, handler, register_args in sorted(load_all_commands(__name__)):
        sub = subparsers.add_parser(
            name[4:].replace("_", "-"),
            aliases=get_cmd_aliases(handler),
            help=handler.__doc__,
        )
        sub.required = True
        register_args(sub)
        sub.set_defaults(func=handler)

    # Shell completion is optional; carry on without it if argcomplete is
    # not installed.
    try:
        import argcomplete

        argcomplete.autocomplete(parser)
    except ImportError:
        pass

    # argparse will set 'func' to the cmd_* that executes this command
    args = parser.parse_args()
    try:
        args.func(args)
    except CommandError as err:
        print(f"E: {err}")
        sys.exit(100)


# Run the command line interface.
# NOTE(review): this call is unconditional, so it also runs if the file is
# ever imported as a module; consider an `if __name__ == "__main__":` guard
# after confirming nothing relies on the import-time call.
main()
