#!/usr/bin/env python3 from __future__ import annotations import argparse, ctypes, struct, hashlib, pickle, code, typing, functools import tinygrad.runtime.autogen.sqtt as sqtt from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent from tinygrad.runtime.ops_amd import ProfileSQTTEvent from tinygrad.helpers import round_up, flatten, all_same from dataclasses import dataclass CHUNK_CLASSES = { sqtt.SQTT_FILE_CHUNK_TYPE_ASIC_INFO: sqtt.struct_sqtt_file_chunk_asic_info, sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DESC: sqtt.struct_sqtt_file_chunk_sqtt_desc, sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DATA: sqtt.struct_sqtt_file_chunk_sqtt_data, sqtt.SQTT_FILE_CHUNK_TYPE_API_INFO: sqtt.struct_sqtt_file_chunk_api_info, sqtt.SQTT_FILE_CHUNK_TYPE_QUEUE_EVENT_TIMINGS: sqtt.struct_sqtt_file_chunk_queue_event_timings, sqtt.SQTT_FILE_CHUNK_TYPE_CLOCK_CALIBRATION: sqtt.struct_sqtt_file_chunk_clock_calibration, sqtt.SQTT_FILE_CHUNK_TYPE_CPU_INFO: sqtt.struct_sqtt_file_chunk_cpu_info, sqtt.SQTT_FILE_CHUNK_TYPE_SPM_DB: sqtt.struct_sqtt_file_chunk_spm_db, sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_DATABASE: sqtt.struct_sqtt_file_chunk_code_object_database, sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_LOADER_EVENTS: sqtt.struct_sqtt_file_chunk_code_object_loader_events, sqtt.SQTT_FILE_CHUNK_TYPE_PSO_CORRELATION: sqtt.struct_sqtt_file_chunk_pso_correlation, } def pretty(val, pad=0) -> str: if isinstance(val, (ctypes.Structure, ctypes.Union)): nl = '\n' # old python versions don't support \ in f-strings return f"{val.__class__.__name__}({nl}{' '*(pad+2)}{(f', {nl}'+' '*(pad+2)).join([f'{field[0]}={pretty(getattr(val, field[0]), pad=pad+2)}' for field in val._fields_])}{nl}{' '*pad})" if isinstance(val, ctypes.Array): return f"[{', '.join(map(pretty, val))}]" if isinstance(val, int) and val >= 1024: return hex(val) return repr(val) @dataclass(frozen=True) class RGPChunk: header: sqtt.Structure data: list[typing.Any]|list[tuple[typing.Any, bytes]]|bytes|None = None def print(self): print(pretty(self.header)) # if isinstance(self.data, bytes): print(repr(self.data)) if isinstance(self.data, list): for dchunk in self.data: if isinstance(dchunk, tuple): print(pretty(dchunk[0])) # print(repr(dchunk[1])) else: print(pretty(dchunk)) # TODO: `def fixup` and true immutability def to_bytes(self, offset:int) -> bytes: cid = self.header.header.chunk_id.type match cid: case _ if cid in {sqtt.SQTT_FILE_CHUNK_TYPE_ASIC_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_CPU_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_API_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DESC}: self.header.header.size_in_bytes = ctypes.sizeof(self.header) return bytes(self.header) case sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DATA: assert isinstance(self.data, bytes) self.header.header.size_in_bytes = ctypes.sizeof(self.header) + len(self.data) self.header.offset = offset+ctypes.sizeof(self.header) self.header.size = len(self.data) return bytes(self.header) + self.data case sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_DATABASE: assert isinstance(self.data, list) data_codb = typing.cast(list[tuple[sqtt.struct_sqtt_code_object_database_record, bytes]], self.data) ret = bytearray() sz = ctypes.sizeof(self.header)+sum([ctypes.sizeof(record_hdr)+round_up(len(record_blob), 4) for record_hdr,record_blob in data_codb]) self.header.header.size_in_bytes = sz self.header.offset = offset self.header.record_count = len(data_codb) self.header.size = sz ret += self.header for record_hdr,record_blob in data_codb: record_hdr.size = round_up(len(record_blob), 4) ret += record_hdr ret += record_blob.ljust(4, b'\x00') return ret case sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_LOADER_EVENTS: assert isinstance(self.data, list) data_lev = typing.cast(list[tuple[sqtt.struct_sqtt_code_object_loader_events_record]], self.data) self.header.header.size_in_bytes = ctypes.sizeof(self.header)+ctypes.sizeof(sqtt.struct_sqtt_code_object_loader_events_record)*len(data_lev) self.header.offset = offset self.header.record_size = ctypes.sizeof(sqtt.struct_sqtt_code_object_loader_events_record) self.header.record_count = len(data_lev) return bytes(self.header) + b''.join(map(bytes, data_lev)) case sqtt.SQTT_FILE_CHUNK_TYPE_PSO_CORRELATION: assert isinstance(self.data, list) data_pso = typing.cast(list[tuple[sqtt.struct_sqtt_pso_correlation_record]], self.data) self.header.header.size_in_bytes = ctypes.sizeof(self.header)+ctypes.sizeof(sqtt.struct_sqtt_pso_correlation_record)*len(data_pso) self.header.offset = offset self.header.record_size = ctypes.sizeof(sqtt.struct_sqtt_pso_correlation_record) self.header.record_count = len(data_pso) return bytes(self.header) + b''.join(map(bytes, data_pso)) case _: raise NotImplementedError(pretty(self.header)) @dataclass(frozen=True) class RGP: header: sqtt.struct_sqtt_file_header chunks: list[RGPChunk] @staticmethod def from_bytes(blob: bytes) -> RGP: file_header = sqtt.struct_sqtt_file_header.from_buffer_copy(blob) assert file_header.magic_number == sqtt.SQTT_FILE_MAGIC_NUMBER and file_header.version_major == sqtt.SQTT_FILE_VERSION_MAJOR i = file_header.chunk_offset chunks = [] while i < len(blob): assert i%4==0, hex(i) hdr = sqtt.struct_sqtt_file_chunk_header.from_buffer_copy(blob, i) cid = hdr.chunk_id.type header: ctypes.Structure match cid: case _ if cid in {sqtt.SQTT_FILE_CHUNK_TYPE_RESERVED, sqtt.SQTT_FILE_CHUNK_TYPE_QUEUE_EVENT_TIMINGS, sqtt.SQTT_FILE_CHUNK_TYPE_CLOCK_CALIBRATION, sqtt.SQTT_FILE_CHUNK_TYPE_SPM_DB}: chunk = None case sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_DATABASE: header = sqtt.struct_sqtt_file_chunk_code_object_database.from_buffer_copy(blob, i) j = header.offset + ctypes.sizeof(header) data: list = [] while j < header.offset + header.size: rec_hdr: ctypes.Structure = sqtt.struct_sqtt_code_object_database_record.from_buffer_copy(blob, j) data.append((rec_hdr, elf:=blob[j+ctypes.sizeof(rec_hdr):j+ctypes.sizeof(rec_hdr)+rec_hdr.size])) assert elf[:4] == b'\x7fELF', repr(elf[:16]) j += ctypes.sizeof(rec_hdr)+rec_hdr.size assert len(data) == header.record_count chunk = RGPChunk(header, data) case sqtt.SQTT_FILE_CHUNK_TYPE_CODE_OBJECT_LOADER_EVENTS: header = sqtt.struct_sqtt_file_chunk_code_object_loader_events.from_buffer_copy(blob, i) data = [sqtt.struct_sqtt_code_object_loader_events_record.from_buffer_copy(blob, header.offset+ctypes.sizeof(header)+j*header.record_size) for j in range(header.record_count)] chunk = RGPChunk(header, data) case sqtt.SQTT_FILE_CHUNK_TYPE_PSO_CORRELATION: header = sqtt.struct_sqtt_file_chunk_pso_correlation.from_buffer_copy(blob, i) data = [sqtt.struct_sqtt_pso_correlation_record.from_buffer_copy(blob, header.offset+ctypes.sizeof(header)+j*header.record_size) for j in range(header.record_count)] chunk = RGPChunk(header, data) case sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DATA: header = sqtt.struct_sqtt_file_chunk_sqtt_data.from_buffer_copy(blob, i) chunk = RGPChunk(header, blob[header.offset:header.offset+header.size]) case _ if cid in {sqtt.SQTT_FILE_CHUNK_TYPE_ASIC_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_CPU_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_API_INFO, sqtt.SQTT_FILE_CHUNK_TYPE_SQTT_DESC}: chunk = RGPChunk(CHUNK_CLASSES[cid].from_buffer_copy(blob, i)) case _: chunk = None print(f"unknown chunk id {cid}") if chunk is not None: chunks.append(chunk) i += hdr.size_in_bytes assert i == len(blob), f'{i} != {len(blob)}' return RGP(file_header, chunks) @staticmethod def from_profile(profile_pickled, device:str|None=None): profile: list[ProfileEvent] = pickle.loads(profile_pickled) device_events = {x.device:x for x in profile if isinstance(x, ProfileDeviceEvent) and x.device.startswith('AMD')} if device is None: if len(device_events) == 0: raise RuntimeError('No supported devices found in profile') if len(device_events) > 1: raise RuntimeError(f"More than one supported device found, select which one to export: {', '.join(device_events.keys())}") _, device_event = device_events.popitem() else: if device not in device_events: raise RuntimeError(f"Device {device} not found in profile, devices in profile: {', '.join(device_events.keys())} ") device_event = device_events[device] sqtt_events = [x for x in profile if isinstance(x, ProfileSQTTEvent) and x.device == device_event.device] if len(sqtt_events) == 0: raise RuntimeError(f"Device {device_event.device} doesn't contain SQTT data") sqtt_itrace_enabled = any([event.itrace for event in sqtt_events]) sqtt_itrace_masked = not all_same([event.itrace for event in sqtt_events]) sqtt_itrace_se_mask = functools.reduce(lambda a,b: a|b, [int(event.itrace) << event.se for event in sqtt_events], 0) if sqtt_itrace_masked else 0 load_events = [x for x in profile if isinstance(x, ProfileProgramEvent) and x.device == device_event.device] loads = [(event.base, struct.unpack(' bytes: ret = bytearray() ret += self.header for chunk in self.chunks: ret += chunk.to_bytes(len(ret)) return bytes(ret) def print(self): print(pretty(self.header)) for chunk in self.chunks: chunk.print() if __name__ == '__main__': parser = argparse.ArgumentParser(prog='rgptool', description='A tool to create (from pickled tinygrad profile), inspect and modify Radeon GPU Profiler files') parser.add_argument('command') parser.add_argument('input') parser.add_argument('-d', '--device') parser.add_argument('-o', '--output') args = parser.parse_args() with open(args.input, 'rb') as fd: input_bytes = fd.read() match args.command: case 'print': rgp = RGP.from_bytes(input_bytes) rgp.print() case 'create': rgp = RGP.from_profile(input_bytes, device=args.device) # rgp.to_bytes() # fixup # rgp.print() case 'repl': rgp = RGP.from_bytes(input_bytes) code.interact(local=locals()) case _: raise RuntimeError(args.command) if args.output is not None: with open(args.output, 'wb+') as fd: fd.write(rgp.to_bytes())