You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
5.9 KiB
101 lines
5.9 KiB
from typing import Any, cast, Optional
|
|
import ctypes
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.helpers import dedup, getenv
|
|
from tinygrad.device import Buffer
|
|
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
|
from tinygrad.engine.jit import GraphRunner, GraphException
|
|
from tinygrad.ops import Variable
|
|
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
|
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
|
|
|
|
class MTLIndirectCommandType:
  """Command-type bit flags for a Metal indirect command buffer descriptor.

  The value is passed to `setCommandTypes:` when the descriptor is configured.
  """
  # bit 5 of the command-type mask: concurrent compute dispatch
  MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
|
|
|
|
class MTLResourceUsage:
  """Usage bit flags passed to `useResources:count:usage:` on a compute encoder.

  The two flags occupy separate bits, so they can be OR-ed together to request both.
  """
  MTLResourceUsageRead = 0b01
  MTLResourceUsageWrite = 0b10
|
|
|
|
class MetalGraph(GraphRunner):
  """Runs a JIT cache of compiled kernels through one Metal indirect command buffer (ICB).

  `__init__` encodes one compute command per entry of `jit_cache` into the ICB. `__call__` then
  patches the input buffers, launch dimensions and variable values, and submits the whole ICB
  inside a single Metal command buffer.
  """

  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    """Build the indirect command buffer for `jit_cache`.

    Args:
      jit_cache: items to encode; every `ji.prg` must be a `CompiledRunner`.
      input_rawbuffers: buffers that are bound later in `__call__` rather than here.
      var_vals: variable values used to compute the initial launch dimensions.

    Raises:
      GraphException: if an item is not a `CompiledRunner`, or the ICB could not be created.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    # create metal batch exec
    icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
    msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
    msg(icb_descriptor, "setInheritBuffers:", False)
    msg(icb_descriptor, "setInheritPipelineState:", False)
    msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)

    self.icb = msg(self.dev.sysdevice, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
                   icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
    if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
    # the ICB's description string is used to decide whether the pipeline-state workaround in __call__ is needed
    icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
    self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label)    # not required on M3

    # one int32 slot per variable; kernels read their variables out of this buffer
    if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
    all_resources = [self.int_buf.buf] if len(self.vars) else []
    all_pipelines = []
    for j,ji in enumerate(jit_cache):
      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
      all_pipelines.append(prg._prg.pipeline_state)
      msg(icb_command, "setComputePipelineState:", prg._prg.pipeline_state)
      # buffers that are also graph inputs are skipped here and bound on every __call__
      for i,b in enumerate(ji.bufs):
        if b is not None and b not in input_rawbuffers:
          msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
          all_resources.append(b._buf.buf)
      # variables are bound after the kernel's buffers, each pointing at its own int32 slot of int_buf
      for i,v in enumerate(prg.p.vars):
        msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*dtypes.int32.itemsize, len(ji.bufs)+i)

      global_size, local_size = prg.p.launch_dims(var_vals)
      msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
      msg(icb_command, "setBarrier")

    self.all_resources = dedup(all_resources)
    self.all_pipelines = dedup(all_pipelines)
    self.command_buffer: Any = None
    # CPU-side int view over int_buf, written in __call__ with the current variable values
    if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
    self.range = to_struct(0, len(jit_cache))

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the ICB for this call and submit it.

    Args:
      input_rawbuffers: the input buffers for this run.
      var_vals: the variable values for this run.
      wait: when true, wait on the command buffer and return its timing.

    Returns:
      `cmdbuf_en_time - cmdbuf_st_time` of the submitted command buffer when `wait` is true, otherwise None.
    """
    # the ICB is shared between calls, so wait on the previous submission while it is still in flight
    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
    all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

    # rebind this call's input buffers; input_replace is keyed by (command index, buffer index)
    for (j,i),input_idx in self.input_replace.items():
      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
      msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
          input_rawbuffers[input_idx]._buf.offset, i)

    # re-encode the dispatch for every command that updated_launch_dims yields for these var_vals
    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_dims), to_struct(*local_dims))
    for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

    command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
    msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
        MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)

    # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
    # this is a O(n) hack to get them used. what should work is:
    #encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
    # but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
    # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
    if getenv("FIX_METAL_ICB", self.needs_icb_fix):
      for ps in self.all_pipelines:
        msg(encoder, "setComputePipelineState:", ps)
        msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))

    msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
    msg(encoder, "endEncoding")
    msg(command_buffer, "setLabel:", to_ns_str(f"batched {len(self.jit_cache)}"))
    msg(command_buffer, "commit")
    self.command_buffer = command_buffer

    self.dev.mtl_buffers_in_flight.append(command_buffer)
    if wait:
      wait_check(command_buffer)
      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
    return None
|
|
|