openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

102 lines
5.9 KiB

1 month ago
from typing import Any, cast, Optional
import ctypes
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
class MTLIndirectCommandType:
  # mirrors Metal's MTLIndirectCommandType option set; only the bit this file needs is defined
  MTLIndirectCommandTypeConcurrentDispatch = 0b100000  # == (1 << 5)
class MTLResourceUsage:
  # mirrors Metal's MTLResourceUsage option set (read/write bits, OR-able)
  MTLResourceUsageRead = 1 << 0
  MTLResourceUsageWrite = 1 << 1
class MetalGraph(GraphRunner):
  """GraphRunner that replays an entire JIT cache as a single Metal indirect command buffer (ICB).

  Every kernel dispatch is encoded once in __init__; each __call__ then only patches
  input buffer bindings, launch dimensions, and symbolic variable values before
  committing one command buffer for the whole batch.
  """
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    # the ICB path only handles compiled compute kernels; raising makes the JIT fall back to unbatched execution
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    # create metal batch exec: one ICB with a concurrent-dispatch command per jit cache entry.
    # buffers and pipeline state are encoded per-command below, so they are not inherited from the encoder.
    icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
    msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
    msg(icb_descriptor, "setInheritBuffers:", False)
    msg(icb_descriptor, "setInheritPipelineState:", False)
    msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)  # presumably Metal's max buffer bind slots — TODO confirm
    self.icb = msg(self.dev.sysdevice, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
      icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
    if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
    # sniff the concrete ICB class name to detect hardware that needs the pipeline-state workaround in __call__
    icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
    self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3

    # one int32 slot per symbolic variable; bound as an extra kernel buffer to every command that reads vars
    if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
    all_resources = [self.int_buf.buf] if len(self.vars) else []
    all_pipelines = []
    for j,ji in enumerate(jit_cache):
      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
      all_pipelines.append(prg._prg.pipeline_state)
      msg(icb_command, "setComputePipelineState:", prg._prg.pipeline_state)
      for i,b in enumerate(ji.bufs):
        # input buffers are rebound on every __call__ (via input_replace), so only non-input buffers are encoded here
        if b is not None and b not in input_rawbuffers:
          msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
          all_resources.append(b._buf.buf)
      # variable bindings follow the regular buffers; each reads its int32 from int_buf at a fixed 4-byte offset
      for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
      global_size, local_size = prg.p.launch_dims(var_vals)
      msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
      msg(icb_command, "setBarrier")

    self.all_resources = dedup(all_resources)
    self.all_pipelines = dedup(all_pipelines)
    # last committed command buffer; checked on reuse so we never patch an ICB the GPU is still executing
    self.command_buffer: Any = None
    # CPU-visible int32 view over int_buf, used to update variable values without re-encoding commands
    if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
    self.range = to_struct(0, len(jit_cache))  # execute all encoded commands

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the pre-encoded ICB with fresh inputs/dims/vars, then commit one command buffer.

    Returns the command buffer's GPU elapsed time (end - start) when wait=True, else None.
    """
    # patching an in-flight ICB would race the GPU; wait for the previous run of this graph to finish first
    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
    all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

    # rebind the input buffers for this call at their recorded (command, argument-index) slots
    for (j,i),input_idx in self.input_replace.items():
      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
      msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
        input_rawbuffers[input_idx]._buf.offset, i)
    # re-encode launch dimensions for kernels whose dims depend on the current var_vals
    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_dims), to_struct(*local_dims))
    # write current variable values into the shared int buffer
    for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

    command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
    # every buffer touched by any command must be declared to the encoder for residency
    msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
      MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)

    # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
    # this is a O(n) hack to get them used. what should work is:
    #encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
    # but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
    # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
    if getenv("FIX_METAL_ICB", self.needs_icb_fix):
      for ps in self.all_pipelines:
        msg(encoder, "setComputePipelineState:", ps)
        # zero-sized dispatch: does no work, but marks the pipeline state as used by this encoder
        msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))

    msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
    msg(encoder, "endEncoding")
    msg(command_buffer, "setLabel:", to_ns_str(f"batched {len(self.jit_cache)}"))
    msg(command_buffer, "commit")
    self.command_buffer = command_buffer

    self.dev.mtl_buffers_in_flight.append(command_buffer)
    if wait:
      wait_check(command_buffer)
      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
    return None