"""Compare kernel run times between two AMDDevice configurations.

"AMD" runs with ``AMDDevice.driverless = False`` and "AM" with
``AMDDevice.driverless = True``. Every loaded AST is compiled and timed on both,
the outputs are compared with each other and, when CHECK_CPU is set, with CPU.

Environment variables read by the script: CHECK_CPU, SEED, CNT, NUM, DEBUG.
"""
from tinygrad import Device, dtypes
from tinygrad.helpers import getenv, colorize_float, DEBUG
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from test.external.fuzz_linearizer import get_fuzz_rawbufs
from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.engine.search import bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.tensor import _to_np_dtype
from tinygrad.runtime.ops_amd import AMDDevice
from contextlib import contextmanager
import numpy as np
import os, random, statistics

# AMDDevice keeps signal pages / signal pool / device list as class attributes.
# The two configurations each get their own copy here, swapped in by run_am()/run_amd().
am_signal_pages, am_signal_pool, am_devices = [], [], []
amd_signal_pages, amd_signal_pool, amd_devices = [], [], []

def rebind_vfio(pcibus="0000:44:00.0"):
  """Rebind the PCI device `pcibus` to vfio-pci, then reload amdgpu and set rocm-smi profiles.

  Shells out with sudo and writes to sysfs, so it needs the matching privileges.
  Return codes of the shell commands are not checked.
  """
  print("rebind ", pcibus)
  os.system("sudo rmmod amdgpu")
  os.system("sudo modprobe vfio-pci")

  base = f"/sys/bus/pci/devices/{pcibus}"
  # only unbind when a driver is currently attached to the device
  if os.path.exists(f"{base}/driver"):
    with open(f"{base}/driver/unbind", "w") as f: f.write(pcibus)
  with open(f"{base}/driver_override", "w") as f: f.write("vfio-pci")
  with open("/sys/bus/pci/drivers_probe", "w") as f: f.write(pcibus)

  os.system("sudo modprobe amdgpu")
  os.system("rocm-smi --setprofile compute")
  os.system("rocm-smi --setperflevel high")

@contextmanager
def run_amd():
  """Run the body with AMDDevice.driverless = False and the saved "amd" class state installed.

  On exit (including when the body raises) the current AMDDevice signal pages, signal pool
  and device list are stored back into the module globals and the class attributes are
  reset to fresh empty lists.
  """
  global amd_signal_pages, amd_signal_pool, amd_devices
  AMDDevice.driverless = False
  AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices = amd_signal_pages, amd_signal_pool, amd_devices
  try:
    yield
  finally:
    # always save + reset, otherwise an exception would leak this state into the other configuration
    amd_signal_pages, amd_signal_pool, amd_devices = AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices
    AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices = [], [], []

@contextmanager
def run_am():
  """Run the body with AMDDevice.driverless = True and the saved "am" class state installed.

  On exit (including when the body raises) the current AMDDevice signal pages, signal pool
  and device list are stored back into the module globals and the class attributes are
  reset to fresh empty lists.
  """
  global am_signal_pages, am_signal_pool, am_devices
  AMDDevice.driverless = True
  AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices = am_signal_pages, am_signal_pool, am_devices
  try:
    yield
  finally:
    # always save + reset, otherwise an exception would leak this state into the other configuration
    am_signal_pages, am_signal_pool, am_devices = AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices
    AMDDevice.signal_pages, AMDDevice.signal_pool, AMDDevice.devices = [], [], []

if __name__ == "__main__":
  CHECK_CPU = getenv("CHECK_CPU", 0)
  SEED = getenv("SEED", 42)
  CNT = getenv("CNT", 7)  # number of timed runs per kernel and configuration

  random.seed(SEED)
  np.random.seed(SEED)
  # TODO: NUM=780 is super slow
  # NUM=1907 is broken on AMD and AM have some mismatches (0 vs 1)

  # kfd feels so bad when taking gpu out while it's running... Need hacks to rebind it before running.
  rebind_vfio(pcibus="0000:44:00.0")

  ast_strs = load_worlds(filter_reduce=False, filter_novariable=True)
  with run_am():
    amdev = Device["AMD:1"]
  with run_amd():
    amddev = Device["AMD"]
  if CHECK_CPU: cpudev = Device["CPU"]

  # NUM selects a single kernel by index
  single = getenv("NUM", -1)
  if single != -1: ast_strs = ast_strs[single:single+1]

  average_tm_amd, average_tm_am = 0, 0
  for num,ast in enumerate(ast_strs):
    with run_amd():
      amdlin = ast_str_to_lin(ast, opts=amddev.renderer)
      amdlin.apply_opts(hand_coded_optimizations(amdlin))
      # kernels touching bfloat16 buffers skip the fuzzed inputs and the result comparison
      has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.membufs)

      amd_prg = CompiledRunner(amdlin.to_program())
      amdbufs = bufs_from_lin(amdlin)
      test_amdbufs = get_fuzz_rawbufs(amdlin) if not has_bf16 else amdbufs
      # snapshot the fuzzed inputs so the other devices are fed the same data
      if not has_bf16: contents = [buf.as_buffer() for buf in test_amdbufs]

    with run_am():
      rdr = amdev.renderer
      rdr.device = "AMD:1"
      amlin = ast_str_to_lin(ast, opts=amdev.renderer)
      amlin.apply_opts(hand_coded_optimizations(amlin))
      am_prg = CompiledRunner(amlin.to_program())
      ambufs = bufs_from_lin(amlin)
      test_ambufs = get_fuzz_rawbufs(amlin) if not has_bf16 else ambufs
      if not has_bf16:
        for i,rawbuf in enumerate(test_ambufs): rawbuf.copyin(contents[i])

    if CHECK_CPU:
      cpu_rdr = cpudev.renderer
      cpu_rdr.device = "CPU"
      cpulin = ast_str_to_lin(ast, opts=cpu_rdr)
      cpulin.apply_opts(hand_coded_optimizations(cpulin))
      cpu_prg = CompiledRunner(cpulin.to_program())
      cpubufs = bufs_from_lin(cpulin)
      # fall back to the CPU buffers (was `ambufs`, which belong to the AM device)
      test_cpubufs = get_fuzz_rawbufs(cpulin) if not has_bf16 else cpubufs
      if not has_bf16:
        for i,rawbuf in enumerate(test_cpubufs): rawbuf.copyin(contents[i])

    # the first call on the test buffers serves as warmup and produces the outputs compared below;
    # the following CNT calls are the timed ones
    tm_amd, tm_am, failed = [], [], False
    with run_amd():
      try:
        amd_prg(test_amdbufs, {}, wait=True)
        for _ in range(CNT): tm_amd.append(amd_prg(amdbufs, {}, wait=True))
      except RuntimeError:
        print("AMD FAILED")
        tm_amd = [1e9]  # sentinel time for a failed kernel
        failed = True

    with run_am():
      try:
        am_prg(test_ambufs, {}, wait=True)
        for _ in range(CNT): tm_am.append(am_prg(ambufs, {}, wait=True))
      except RuntimeError:
        print("AM FAILED")
        tm_am = [1e9]  # sentinel time for a failed kernel
        failed = True

    if CHECK_CPU:
      cpu_prg(test_cpubufs, {}, wait=True)
      for _ in range(1): cpu_prg(cpubufs, {}, wait=True)

    if not failed and not has_bf16:
      # buffer 0 of each test buffer list is read back as the kernel result
      with run_amd():
        curesult = np.frombuffer(test_amdbufs[0].as_buffer(), _to_np_dtype(test_amdbufs[0].dtype))
      with run_am():
        amresult = np.frombuffer(test_ambufs[0].as_buffer(), _to_np_dtype(test_ambufs[0].dtype))
      if CHECK_CPU:
        # a mismatch against CPU is fatal
        cpuresult = np.frombuffer(test_cpubufs[0].as_buffer(), _to_np_dtype(test_cpubufs[0].dtype))
        np.testing.assert_allclose(amresult, cpuresult, rtol=1e-2, atol=1e-2)
        np.testing.assert_allclose(curesult, cpuresult, rtol=1e-2, atol=1e-2)

      # a mismatch between AM and AMD is only reported, the benchmark keeps going
      try:
        np.testing.assert_allclose(curesult, amresult, rtol=1e-2, atol=1e-2)
      except AssertionError as e:
        print("AM and AMD results do not match")
        print(e)

    bam = statistics.median(tm_am)
    bamd = statistics.median(tm_amd)
    average_tm_amd += bamd
    average_tm_am += bam
    ratio = bam/bamd
    # first column is the cumulative AM/AMD time ratio over all kernels so far
    print(f"{average_tm_am/average_tm_amd:5.2f}x -- {num:4d} {colorize_float(ratio)} {bam*1e6:7.2f} vs {bamd*1e6:7.2f} us", amlin.name)
    if DEBUG > 3 and ratio > 1.04: print(f"AM slower {ratio}", amlin.ast, amlin.applied_opts)