import os, mmap, array, functools, ctypes, select, contextlib, dataclasses
from typing import cast
from tinygrad.helpers import round_up, to_mv, getenv, OSX
from tinygrad.runtime.autogen import libc, vfio
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQCompiled, HCQBuffer
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping

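# Raw mmap flag values used below. Python's mmap module doesn't expose all of them, and the values differ between Linux and macOS (hence the OSX checks).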
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400

class _System:
  def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False, data:bytes|None=None) -> tuple[int, list[int]]:
    assert not contiguous or size <= (2 << 20), "Contiguous allocation is only supported for sizes up to 2MB"
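    # Contiguous allocations bigger than one 4K page are backed by a huge page (typically 2MB), which is why the size is capped at 2MB above.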
    flags = (libc.MAP_HUGETLB if contiguous and (size:=round_up(size, mmap.PAGESIZE)) > 0x1000 else 0) | (MAP_FIXED if vaddr else 0)
    va = FileIOInterface.anon_mmap(vaddr, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|mmap.MAP_ANONYMOUS|MAP_POPULATE|MAP_LOCKED|flags, 0)

    if data is not None: to_mv(va, len(data))[:] = data

    # Read pagemap to get the physical address of each page. The pages are locked.
    self.pagemap().seek(va // mmap.PAGESIZE * 8)
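    # Each 64-bit pagemap entry stores the page frame number in bits 0-54, hence the (1<<55)-1 mask.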
    return va, [(x & ((1<<55) - 1)) * mmap.PAGESIZE for x in array.array('Q', self.pagemap().read(size//mmap.PAGESIZE*8, binary=True))]

  def pci_scan_bus(self, target_vendor:int, target_devices:list[int]) -> list[str]:
    result = []
    for pcibus in FileIOInterface("/sys/bus/pci/devices").listdir():
      vendor = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16)
      device = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16)
      if vendor == target_vendor and device in target_devices: result.append(pcibus)
    return sorted(result)

  @functools.cache
  def pagemap(self) -> FileIOInterface:
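    # Physical addresses read from pagemap stay valid only if the kernel doesn't migrate locked pages, so compaction of unevictable pages must be disabled.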
    if FileIOInterface(reloc_sysfs:="/proc/sys/vm/compact_unevictable_allowed", os.O_RDONLY).read()[0] != "0":
      os.system(cmd:=f"sudo sh -c 'echo 0 > {reloc_sysfs}'")
      assert FileIOInterface(reloc_sysfs, os.O_RDONLY).read()[0] == "0", f"Failed to disable migration of locked pages. Please run {cmd} manually."
    return FileIOInterface("/proc/self/pagemap", os.O_RDONLY)

  @functools.cache
  def vfio(self) -> FileIOInterface|None:
    try:
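      # noiommu mode exposes the device to userspace without IOMMU translation (DMA uses raw physical addresses), so the kernel requires it to be explicitly enabled as "unsafe".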
      if not FileIOInterface.exists("/sys/module/vfio"): os.system("sudo modprobe vfio-pci disable_idle_d3=1")

      FileIOInterface("/sys/module/vfio/parameters/enable_unsafe_noiommu_mode", os.O_RDWR).write("1")
      vfio_fd = FileIOInterface("/dev/vfio/vfio", os.O_RDWR)
      vfio.VFIO_CHECK_EXTENSION(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU)

      return vfio_fd
    except OSError: return None

System = _System()

class PCIDevice:
  def __init__(self, pcibus:str, bars:list[int], resize_bars:list[int]|None=None):
    self.pcibus, self.irq_poller = pcibus, None

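    # Unbind any kernel driver currently attached to the device so it can be driven from userspace.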
    if FileIOInterface.exists(f"/sys/bus/pci/devices/{self.pcibus}/driver"):
      FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver/unbind", os.O_WRONLY).write(self.pcibus)

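    # resource{i}_resize exposes a bitmask of supported BAR sizes; writing the index of the highest set bit grows the BAR to its largest supported size.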
    for i in resize_bars or []:
      supported_sizes = int(FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{i}_resize", os.O_RDONLY).read(), 16)
      try: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{i}_resize", os.O_RDWR).write(str(supported_sizes.bit_length() - 1))
      except OSError as e: raise RuntimeError(f"Cannot resize BAR {i}: {e}. Ensure the resizable BAR option is enabled on your system.") from e

    if getenv("VFIO", 0) and (vfio_fd:=System.vfio()) is not None:
      FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver_override", os.O_WRONLY).write("vfio-pci")
      FileIOInterface("/sys/bus/pci/drivers_probe", os.O_WRONLY).write(self.pcibus)
      iommu_group = FileIOInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1]

      self.vfio_group = FileIOInterface(f"/dev/vfio/noiommu-{iommu_group}", os.O_RDWR)
      vfio.VFIO_GROUP_SET_CONTAINER(self.vfio_group, ctypes.c_int(vfio_fd.fd))

      with contextlib.suppress(OSError): vfio.VFIO_SET_IOMMU(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU) # set iommu works only once for the fd.
      self.vfio_dev = FileIOInterface(fd=vfio.VFIO_GROUP_GET_DEVICE_FD(self.vfio_group, ctypes.create_string_buffer(self.pcibus.encode())))

      self.irq_fd = FileIOInterface.eventfd(0, 0)
      self.irq_poller = select.poll()
      self.irq_poller.register(self.irq_fd.fd, select.POLLIN)

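      # Route the device's MSI interrupt to the eventfd so interrupts can be waited on via poll().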
      irqs = vfio.struct_vfio_irq_set(index=vfio.VFIO_PCI_MSI_IRQ_INDEX, flags=vfio.VFIO_IRQ_SET_DATA_EVENTFD|vfio.VFIO_IRQ_SET_ACTION_TRIGGER,
        argsz=ctypes.sizeof(vfio.struct_vfio_irq_set), count=1, data=(ctypes.c_int * 1)(self.irq_fd.fd))
      vfio.VFIO_DEVICE_SET_IRQS(self.vfio_dev, irqs)
    else: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/enable", os.O_RDWR).write("1")

    self.cfg_fd = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/config", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC)
    self.bar_fds = {b: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{b}", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC) for b in bars}

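    # Each line of the sysfs "resource" file is "start end flags" in hex, one per BAR/region.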
    bar_info = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource", os.O_RDONLY).read().splitlines()
    self.bar_info = {j:(int(start,16), int(end,16), int(flgs,16)) for j,(start,end,flgs) in enumerate(l.split() for l in bar_info)}

  def read_config(self, offset:int, size:int): return int.from_bytes(self.cfg_fd.read(size, binary=True, offset=offset), byteorder='little')
  def write_config(self, offset:int, value:int, size:int): self.cfg_fd.write(value.to_bytes(size, byteorder='little'), binary=True, offset=offset)
  def map_bar(self, bar:int, off:int=0, addr:int=0, size:int|None=None, fmt='B') -> MMIOInterface:
    fd, sz = self.bar_fds[bar], size or (self.bar_info[bar][1] - self.bar_info[bar][0] + 1)
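    # MADV_DONTFORK: forked children must not inherit this MMIO mapping.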
    libc.madvise(loc:=fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz, libc.MADV_DONTFORK)
    return MMIOInterface(loc, sz, fmt=fmt)

class PCIDevImplBase:
  mm: MemoryManager

@dataclasses.dataclass
class PCIAllocationMeta: owner:HCQCompiled; mapped_devs:list; mapping:VirtMapping; has_cpu_mapping:bool # noqa: E702

class PCIIfaceBase:
  dev_impl:PCIDevImplBase
  gpus:list[str] = []

  def __init__(self, dev, dev_id, vendor, devices, bars, vram_bar, va_start, va_size):
    if len(PCIIfaceBase.gpus) == 0:
      PCIIfaceBase.gpus = System.pci_scan_bus(vendor, devices)
      visible_devices = [int(x) for x in (getenv('VISIBLE_DEVICES', '')).split(',') if x.strip()]
      PCIIfaceBase.gpus = [PCIIfaceBase.gpus[x] for x in visible_devices] if visible_devices else PCIIfaceBase.gpus

    # Acquire va range to avoid collisions.
    FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
    self.pci_dev, self.dev, self.vram_bar = PCIDevice(PCIIfaceBase.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
    self.p2p_base_addr = self.pci_dev.bar_info[vram_bar][0]

  def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, **kwargs) -> HCQBuffer:
    if host or (uncached and cpu_access): # host or gtt-like memory.
      vaddr = self.dev_impl.mm.alloc_vaddr(size:=round_up(size, mmap.PAGESIZE), align=mmap.PAGESIZE)
      paddrs = [(paddr, mmap.PAGESIZE) for paddr in System.alloc_sysmem(size, vaddr=vaddr, contiguous=contiguous)[1]]
      mapping = self.dev_impl.mm.map_range(vaddr, size, paddrs, system=True, snooped=True, uncached=True)
      return HCQBuffer(vaddr, size, meta=PCIAllocationMeta(self.dev, [self.dev], mapping, has_cpu_mapping=True),
                       view=MMIOInterface(mapping.va_addr, size, fmt='B'))

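    # Otherwise allocate device-local (VRAM) memory; cpu_access forces a contiguous allocation so it can be exposed through a single BAR window below.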
    mapping = self.dev_impl.mm.valloc(size:=round_up(size, 4 << 10), uncached=uncached, contiguous=cpu_access)
    if cpu_access: self.pci_dev.map_bar(bar=self.vram_bar, off=mapping.paddrs[0][0], addr=mapping.va_addr, size=mapping.size)
    return HCQBuffer(mapping.va_addr, size, view=MMIOInterface(mapping.va_addr, size, fmt='B') if cpu_access else None,
                     meta=PCIAllocationMeta(self.dev, [self.dev], mapping, has_cpu_mapping=cpu_access))

  def free(self, b:HCQBuffer):
    for dev in b.meta.mapped_devs[1:]: dev.dev_iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
    if not b.meta.mapping.system: self.dev_impl.mm.vfree(b.meta.mapping)
    if b.meta.owner == self.dev and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)

  def map(self, b:HCQBuffer):
    # Check if the memory is already mapped on this device.
    if self.dev in b.meta.mapped_devs: return
    b.meta.mapped_devs.append(self.dev)

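    # System memory is mapped by physical address; another device's VRAM is reached through that device's BAR window, so offset by the owner's p2p_base_addr.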
    paddrs = [(paddr if b.meta.mapping.system else (paddr+b.meta.owner.dev_iface.p2p_base_addr), size) for paddr,size in b.meta.mapping.paddrs]
    self.dev_impl.mm.map_range(cast(int, b.va_addr), b.size, paddrs, system=True, snooped=b.meta.mapping.snooped, uncached=b.meta.mapping.uncached)