You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
4.4 KiB
86 lines
4.4 KiB
import base64, ctypes, pathlib, tempfile, hashlib
|
|
from tinygrad.device import Compiler
|
|
from tinygrad.helpers import cpu_objdump, system
|
|
import tinygrad.runtime.autogen.mesa as mesa
|
|
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
|
|
try: import tinygrad.runtime.autogen.llvm as llvm
|
|
except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
|
|
|
|
def deserialize(enc_src, opts):
  """Turn a base64-encoded serialized NIR shader back into a shader object.

  enc_src: base64 encoding of the serialized shader blob.
  opts: compiler options; cast to a `nir_shader_compiler_options` pointer before use.
  Returns the result of `mesa.nir_deserialize` unchanged.
  """
  raw = base64.b64decode(enc_src)
  reader = mesa.struct_blob_reader()
  # the reader is pointed at the decoded bytes, so `raw` has to stay alive until deserialization is done
  mesa.blob_reader_init(reader, raw, len(raw))
  options_ptr = ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options))
  return mesa.nir_deserialize(None, options_ptr, reader)
|
|
|
|
class NIRCompiler(Compiler):
  """Common base for the NIR compilers: takes a reference on mesa's glsl type singleton on init and drops it on delete."""
  def __init__(self, cache_key):
    # take the singleton reference before running the regular Compiler setup
    mesa.glsl_type_singleton_init_or_ref()
    super().__init__(cache_key)

  def __del__(self):
    # matches the init_or_ref call made in __init__
    mesa.glsl_type_singleton_decref()
|
|
|
|
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
  """Compiles serialized NIR shaders to an object file by lowering them through mesa's gallivm and emitting with LLVM."""
  def __init__(self, cache_key="lvp"):
    # both bases are initialized explicitly since they take different arguments
    CPULLVMCompiler.__init__(self)
    NIRCompiler.__init__(self, f"compile_{cache_key}")

  def __del__(self):
    NIRCompiler.__del__(self)
    CPULLVMCompiler.__del__(self)

  def compile(self, src) -> bytes:
    """Compile a base64-encoded serialized NIR shader and return the emitted object file as bytes.

    Raises whatever `expect` raises when LLVM reports an emission error.
    NOTE(review): self.target_machine is not set in this class, presumably CPULLVMCompiler.__init__ provides it.
    """
    shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()
    gallivm = mesa.gallivm_create(None, mesa.lp_context_ref(ctypes.cast(ctx, ctypes.POINTER(mesa.struct_LLVMOpaqueContext)), True), None).contents
    # gallivm owns the module and builder, these are just the same handles viewed through the llvm bindings' types
    module, builder = ctypes.cast(gallivm.module, llvm.LLVMModuleRef), ctypes.cast(gallivm.builder, llvm.LLVMBuilderRef)

    params = mesa.struct_lp_build_tgsi_params(mesa.struct_lp_type(floating=True, sign=True, width=32, length=4),
      resources_type=mesa.lp_build_jit_resources_type(gallivm), mask=ctypes.pointer(mesa.struct_lp_build_mask_context()))

    # the generated function is void, named after the shader, and takes one pointer to the jit resources struct
    pt = llvm.LLVMPointerType(ctypes.cast(params.resources_type, llvm.LLVMTypeRef), 0)
    fn = llvm.LLVMAddFunction(module, shader.contents.info.name, llvm.LLVMFunctionType(llvm.LLVMVoidTypeInContext(ctx), pt, 1, 0))
    llvm.LLVMPositionBuilderAtEnd(builder, llvm.LLVMAppendBasicBlockInContext(ctx, fn, b"entry"))

    params.consts_ptr = mesa.lp_build_struct_get_ptr2(gallivm, params.resources_type,
      ctypes.cast(llvm.LLVMGetParam(fn, 0), mesa.LLVMValueRef), mesa.LP_JIT_RES_CONSTANTS, b"constants")
    mesa.lp_build_mask_begin(params.mask, gallivm, params.type, mesa.lp_build_one(gallivm, params.type))
    mesa.lp_build_mask_end(params.mask)

    mesa.lp_build_nir_soa(gallivm, shader, params, None)
    llvm.LLVMBuildRetVoid(builder)
    mesa.gallivm_verify_function(gallivm, ctypes.cast(fn, mesa.LLVMValueRef))
    mesa.lp_passmgr_run(gallivm.passmgr, gallivm.module, ctypes.cast(self.target_machine, mesa.LLVMTargetMachineRef), gallivm.module_name)
    obj_buf = expect(llvm.LLVMTargetMachineEmitToMemoryBuffer(self.target_machine, module, llvm.LLVMObjectFile, err:=cerr(),
      ctypes.pointer(buf:=llvm.LLVMMemoryBufferRef())), err, buf)
    # string_at copies the bytes, so the LLVM buffer can (and must) be released afterwards or it leaks on every compile
    obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
    llvm.LLVMDisposeMemoryBuffer(obj_buf)

    mesa.gallivm_destroy(gallivm)
    mesa.ralloc_free(shader)
    return obj

  def disassemble(self, lib: bytes): cpu_objdump(lib)
|
|
|
|
class NAKCompiler(NIRCompiler):
  """Compiles serialized NIR shaders with mesa's NAK compiler for one architecture."""
  def __init__(self, arch, warps_per_sm, cache_key="nak"):
    # arch and warps_per_sm are kept on the instance so __reduce__ can rebuild the compiler
    self.arch, self.warps_per_sm = arch, warps_per_sm
    # everything after the first three characters of arch is parsed as the SM number
    dev_info = mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm)
    self.cc = mesa.nak_compiler_create(dev_info)
    # snapshot of the compiler's nir options as raw bytes, handed to deserialize() on each compile
    self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
    super().__init__(f"compile_{cache_key}_{arch}")

  def __del__(self):
    # destroy the NAK compiler first, then let NIRCompiler drop its reference
    mesa.nak_compiler_destroy(self.cc)
    super().__del__()

  def __reduce__(self):
    # pickling recreates the compiler from its constructor arguments (cache_key falls back to its default)
    return (NAKCompiler, (self.arch, self.warps_per_sm))

  def compile(self, src) -> bytes:
    """Compile a base64-encoded serialized NIR shader; returns the shader info struct bytes followed by the code bytes."""
    shader = deserialize(src, self.nir_options)
    mesa.nak_preprocess_nir(shader, self.cc)
    shader_bin = mesa.nak_compile_shader(shader, False, self.cc, 0, None).contents
    # copy both parts out before the shader bin is destroyed
    info_bytes = bytes(shader_bin.info)
    code_bytes = ctypes.string_at(shader_bin.code, shader_bin.code_size)
    mesa.nak_shader_bin_destroy(shader_bin)
    mesa.ralloc_free(shader)
    return info_bytes + code_bytes

  def disassemble(self, lib: bytes):
    """Write the code part of `lib` to a temp file and print nvdisasm's output for it; failures are printed, not raised."""
    try:
      tmp_dir = pathlib.Path(tempfile.gettempdir())
      fn = (tmp_dir / f"tinynak_{hashlib.md5(lib).hexdigest()}").as_posix()
      with open(fn, "wb") as f:
        # skip the nak_shader_info header that compile() puts in front of the code
        f.write(lib[ctypes.sizeof(mesa.struct_nak_shader_info):])
      print(system(f"nvdisasm -b SM{self.arch[3:]} {fn}"))
    except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
|
|
|