from __future__ import annotations
import ctypes, collections, time, dataclasses, functools, fcntl, os, hashlib
from tinygrad.helpers import mv_address, getenv, DEBUG, temp, fetch
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
from tinygrad.runtime.support.system import PCIDevImplBase
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA

AM_DEBUG = getenv("AM_DEBUG", 0)

@dataclasses.dataclass(frozen=True)
class AMRegister(AMDReg):
  """A register description bound to a device: all accesses go through `adev.rreg` / `adev.wreg` at `self.addr`."""
  adev:AMDev

  def read(self): return self.adev.rreg(self.addr)
  def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())

  def write(self, _am_val:int=0, **kwargs):
    """Write `_am_val` OR-ed with the encoded bitfield kwargs (no read of the current value)."""
    self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))

  def update(self, **kwargs):
    """Read-modify-write: clear the bits of the named fields, then write the new field values."""
    self.write(self.read() & ~self.fields_mask(*kwargs.keys()), **kwargs)

class AMFirmware:
  """Fetches the firmware images for `adev` and slices them into loadable pieces.

  Attributes set here:
    sos_fw:       PSP SOS firmware components, keyed by `fw_type`, as memoryview slices.
    smu_psp_desc: (fw types, ucode slice) for the SMU firmware.
    descs:        list of (fw types, ucode slice) for SDMA, PFP/ME/MEC, IMU and RLC firmware.
    ucode_start:  ucode start address per RS64 firmware name ('PFP', 'ME', 'MEC').
  """
  def __init__(self, adev):
    self.adev = adev
    def fmt_ver(hwip): return '_'.join(map(str, adev.ip_ver[hwip]))

    # Load SOS firmware
    self.sos_fw = {}

    blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
    fw_bin = sos_hdr.psp_fw_bin

    # psp_fw_bin is an in-place array of psp_fw_bin_desc entries following the header.
    for fw_i in range(sos_hdr.psp_fw_bin_count):
      fw_bin_desc = am.struct_psp_fw_bin_desc.from_address(ctypes.addressof(fw_bin) + fw_i * ctypes.sizeof(am.struct_psp_fw_bin_desc))
      ucode_start_offset = fw_bin_desc.offset_bytes + sos_hdr.header.ucode_array_offset_bytes
      self.sos_fw[fw_bin_desc.fw_type] = blob[ucode_start_offset:ucode_start_offset+fw_bin_desc.size_bytes]

    # Load other fw
    self.ucode_start: dict[str, int] = {}
    self.descs: list[tuple[list[int], memoryview]] = []

    blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
    self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)

    # SDMA firmware: header v2 carries separate ctl (TH1) and ctx (TH0) ucode, v3 a single TH0 image.
    blob, hdr, hdr_v3 = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0, am.struct_sdma_firmware_header_v3_0)
    if hdr.header.header_version_major < 3:
      self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)]
      self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
    else: self.descs += [self.desc(blob, hdr_v3.header.ucode_array_offset_bytes, hdr_v3.ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]

    # PFP, ME, MEC firmware (PFP and ME are only loaded here for GC >= 12.0.0)
    for (fw_name, fw_cnt) in ([('PFP', 1), ('ME', 1)] if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else []) + [('MEC', 1)]:
      blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)

      # Code part
      self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'))]

      # Stack: the same data slice is registered under every per-pipe stack fw type.
      stack_fws = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)]
      self.descs += [self.desc(blob, hdr.data_offset_bytes, hdr.data_size_bytes, *stack_fws)]
      self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)

    # IMU firmware: IRAM image immediately followed by the DRAM image.
    blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
    imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
    self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)]

    # RLC firmware
    blob, hdr0, _hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
      am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)

    for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
      off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
      self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]

    # RLCP / RLCV images only exist in header minor version 3.
    if hdr0.header.header_version_minor == 3:
      for mem in ['P', 'V']:
        off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
        self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]

    self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)]

  def load_fw(self, fname:str, *headers):
    """Fetch `fname` from a pinned linux-firmware commit and return `(blob, *hdrs)`.

    Each returned header is a ctypes struct overlaid (via `from_address`) on the start of `blob`, so it aliases
    blob's memory — keep `blob` alive for as long as the headers are used.
    """
    fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/45f59212aebd226c7630aff4b58598967c0c8c91/amdgpu/{fname}", subdir="fw")
    blob = memoryview(bytearray(fpath.read_bytes()))
    if AM_DEBUG >= 1: print(f"am {self.adev.devfmt}: loading firmware {fname}: {hashlib.sha256(blob).hexdigest()}")
    return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])

  def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]:
    """Return `(list of fw types, blob[offset:offset+size])`; the slice is a view, not a copy."""
    return (list(types), blob[offset:offset+size])

class AMPageTableEntry:
  """One 4KB page-table page in VRAM at `paddr`, viewed as 64-bit entries, for page-table level `lv`."""
  def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')

  def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
    assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
    # Address occupies bits 12..47 of the entry; flags come from the GMC block.
    self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)

  def entry(self, entry_id:int) -> int: return self.entries[entry_id]
  def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0
  def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000
  def is_pte(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id])

class AMMemoryManager(MemoryManager):
  va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices (class attribute): 512GiB of VA.

  def on_range_mapped(self):
    # Invalidate TLB after mappings.
    self.dev.gmc.flush_tlb(ip='GC', vmid=0)
    self.dev.gmc.flush_tlb(ip='MM', vmid=0)

class AMDev(PCIDevImplBase):
  """AM driver state for one AMD GPU reached through the given VRAM, doorbell and MMIO apertures.

  Construction takes an exclusive, non-blocking flock on a per-device lock file under `temp()` (raising RuntimeError if another
  process holds it), discovers IP blocks and register bases from the VRAM discovery table, then boots the IP blocks.
  """
  def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
    self.devfmt, self.vram, self.doorbell64, self.mmio, self.dma_regions = devfmt, vram, doorbell, mmio, dma_regions

    # Set umask to 0 to allow creating files with 0666 permissions.
    # NOTE(review): this changes the process-wide umask and is never restored here.
    os.umask(0)

    # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
    if os.path.exists(lock_name:=temp(f"am_{self.devfmt}.lock")): self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CLOEXEC)
    else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT | os.O_CLOEXEC, 0o666)

    try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError as e:
      os.close(self.lock_fd) # don't leak the fd when the device is busy
      raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.") from e

    self._run_discovery()
    self._build_regs()

    # AM boot Process:
    # The GPU being passed can be in one of several states: 1. Not initialized. 2. Initialized by amdgpu. 3. Initialized by AM.
    # The 1st and 2nd states require a full GPU setup since their states are unknown. The 2nd state also requires a mode1 reset to
    # reinitialize all components.
    #
    # The 3rd state can be set up partially to optimize boot time. In this case, only the GFX and SDMA IPs need to be initialized.
    # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
    # all blocks that are initialized only during the initial AM boot.
    # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
    self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
    self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000005)) and (getenv("AM_RESET", 0) != 1)

    # Memory manager & firmware
    self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, pte_cnt=[512, 512, 512, 512],
      pte_covers=[(1 << ((9 * (3-lv)) + 12)) for lv in range(4)], first_lv=am.AMDGPU_VM_PDB1, first_page_lv=am.AMDGPU_VM_PDB2,
      va_base=AMMemoryManager.va_allocator.base)
    self.fw = AMFirmware(self)

    # Initialize IP blocks
    self.soc:AM_SOC = AM_SOC(self)
    self.gmc:AM_GMC = AM_GMC(self)
    self.ih:AM_IH = AM_IH(self)
    self.psp:AM_PSP = AM_PSP(self)
    self.smu:AM_SMU = AM_SMU(self)
    self.gfx:AM_GFX = AM_GFX(self)
    self.sdma:AM_SDMA = AM_SDMA(self)

    # Init sw for all IP blocks
    for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init_sw()

    # A partial boot is only trusted if GC VM context 0 is disabled and there is no pending GC page fault status.
    if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0 or self.reg(self.gmc.pf_status_reg("GC")).read() != 0):
      if DEBUG >= 2: print(f"am {self.devfmt}: Malformed state. Issuing a full reset.")
      self.partial_boot = False

    # Init hw for IP blocks where it is needed
    if not self.partial_boot:
      if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
      for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu]:
        ip.init_hw()
        if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

    # Booting done
    self.is_booting = False

    # Re-initialize main blocks
    for ip in [self.gfx, self.sdma]:
      ip.init_hw()
      if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

    self.smu.set_clocks(level=-1) # last level, max perf.
    for ip in [self.soc, self.gfx]: ip.set_clockgating_state()
    self.reg("regSCRATCH_REG7").write(am_version) # mark the GPU as AM-initialized for the next partial boot
    if DEBUG >= 2: print(f"am {self.devfmt}: boot done")

  def fini(self):
    """Tear down SDMA/GFX, drop clocks to level 0, drain the interrupt handler, and release the device lock."""
    if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
    for ip in [self.sdma, self.gfx]: ip.fini_hw()
    self.smu.set_clocks(level=0)
    self.ih.interrupt_handler()
    os.close(self.lock_fd)

  def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr

  def reg(self, reg:str) -> AMRegister:
    """Look up a register object by name; `_build_regs` installs them as instance attributes."""
    return self.__dict__[reg]

  def rreg(self, reg:int) -> int:
    """Read register `reg` (dword index): directly from the MMIO window if it fits, else via the RSMU index/data pair."""
    val = self.indirect_rreg(reg * 4) if reg >= len(self.mmio) else self.mmio[reg]
    if AM_DEBUG >= 4 and getattr(self, '_prev_rreg', None) != (reg, val): print(f"am {self.devfmt}: Reading register {reg:#x} with value {val:#x}")
    self._prev_rreg = (reg, val)
    return val

  def wreg(self, reg:int, val:int):
    """Write register `reg` (dword index): directly into the MMIO window if it fits, else via the RSMU index/data pair."""
    if AM_DEBUG >= 4: print(f"am {self.devfmt}: Writing register {reg:#x} with value {val:#x}")
    if reg >= len(self.mmio): self.indirect_wreg(reg * 4, val)
    else: self.mmio[reg] = val

  def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int):
    """Split a 64-bit `val` across the `{reg_base}{lo_suffix}` (low 32 bits) and `{reg_base}{hi_suffix}` (high bits) registers."""
    self.reg(f"{reg_base}{lo_suffix}").write(val & 0xffffffff)
    self.reg(f"{reg_base}{hi_suffix}").write(val >> 32)

  def indirect_rreg(self, reg:int) -> int:
    self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
    return self.reg("regBIF_BX_PF0_RSMU_DATA").read()

  def indirect_wreg(self, reg:int, val:int):
    self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
    self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)

  def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
    """Poll `reg` until `(reg & mask) == value`; return the matching raw value.

    `timeout` is in milliseconds. The register is always read at least once, so a non-positive timeout still polls once.
    Raises RuntimeError (including the last read value) on timeout.
    """
    start_time = int(time.perf_counter() * 1000)
    while True:
      if ((rval:=reg.read()) & mask) == value: return rval
      if int(time.perf_counter() * 1000) - start_time >= timeout: break
    raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval:X}')

  def _run_discovery(self):
    """Read VRAM size and the IP discovery table; fill `vram_size`, `bhdr`, `regs_offset`, `ip_ver` and `gc_info`."""
    # NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
    #       The table is located at the end of VRAM - 64KB and is 10KB in size.
    mmRCC_CONFIG_MEMSIZE = 0xde3
    self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
    tmr_offset, tmr_size = self.vram_size - (64 << 10), (10 << 10)

    # Copy the table out of VRAM; all header structs below are overlaid on this host-side copy.
    self.bhdr = am.struct_binary_header.from_buffer(bytearray(self.vram.view(tmr_offset, tmr_size)[:]))
    ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.IP_DISCOVERY].offset)
    assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, \
      f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X} or unsupported base_addr_64_bit={ihdr.base_addr_64_bit}"

    # Mapping of HW IP to Discovery HW IP
    hw_id_map = {am.__dict__[x]: int(y) for x,y in am.hw_id_map}
    self.regs_offset:dict[int, dict[int, tuple]] = collections.defaultdict(dict)
    self.ip_ver:dict[int, tuple[int, int, int]] = {}

    for num_die in range(ihdr.num_dies):
      dhdr = am.struct_die_header.from_address(ctypes.addressof(self.bhdr) + ihdr.die_info[num_die].die_offset)

      ip_offset = ctypes.addressof(self.bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
      for _ in range(dhdr.num_ips):
        ip = am.struct_ip_v4.from_address(ip_offset)
        # Base addresses follow the 8-byte fixed part of the IP entry (32-bit each, as asserted above).
        ba = (ctypes.c_uint32 * ip.num_base_address).from_address(ip_offset + 8)
        for hw_ip in range(1, am.MAX_HWIP):
          if hw_ip in hw_id_map and hw_id_map[hw_ip] == ip.hw_id:
            self.regs_offset[hw_ip][ip.instance_number] = tuple(ba)
            self.ip_ver[hw_ip] = (ip.major, ip.minor, ip.revision)

        ip_offset += 8 + (8 if ihdr.base_addr_64_bit else 4) * ip.num_base_address

    # Read the version from a v1.0 view first, then re-read with the matching struct version.
    gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.GC].offset)
    self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)

  def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): return import_module(prefix, self.ip_ver[hwip], prever_prefix)

  def _build_regs(self):
    """Install AMRegister objects for each IP block as instance attributes, using instance 0's discovered base addresses."""
    mods = [("mp", am.MP0_HWIP), ("hdp", am.HDP_HWIP), ("gc", am.GC_HWIP), ("mmhub", am.MMHUB_HWIP), ("osssys", am.OSSSYS_HWIP),
      ("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]

    for prefix, hwip in mods:
      self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip][0])))
    # MP1 registers come from the fixed 'mp' 11.0 definitions, rebased onto the discovered MP1 bases.
    self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP][0])))