from __future__ import annotations
import os , pathlib , struct , ctypes , tempfile , functools , decimal
from typing import Any , Union , cast
from tinygrad . helpers import prod , to_mv , getenv , round_up , cache_dir , T , init_c_struct_t , PROFILE
from tinygrad . device import Compiled , Compiler , CompileError , LRUAllocator , cpu_profile , ProfileDeviceEvent , ProfileRangeEvent
from tinygrad . renderer . cstyle import MetalRenderer
class objc_id(ctypes.c_void_p):
  """Opaque Objective-C object pointer.

  Subclassing ``c_void_p`` prevents ctypes from converting a returned pointer to a plain int.
  Hashing and equality are based on the pointer value so ``dict.fromkeys()`` can dedup handles.
  """
  def __hash__(self):
    return hash(self.value)

  def __eq__(self, other):
    # Compare by pointer value. Anything without a ``.value`` (None, ints, ...) is simply "not equal":
    # returning NotImplemented lets Python fall back to its default comparison instead of raising AttributeError.
    try:
      return self.value == other.value
    except AttributeError:
      return NotImplemented
class objc_instance(objc_id):
  """Owning Objective-C handle: sends a release message to the object when the Python wrapper is destroyed.

  Intended for results of methods named "new", "alloc", etc., which should be freed after use.
  """
  def __del__(self):
    msg(self, " release ")
@functools.lru_cache(maxsize=None)
def sel(name: str):
  """Register (or look up) the Objective-C selector called `name`; results are memoized per name."""
  encoded_name = name.encode()
  return libobjc.sel_registerName(encoded_name)
class MTLResourceOptions:
  """Integer flag values passed as the resource `options` argument when creating Metal buffers."""
  # Both flags evaluate to 0; the shift keeps the storage-mode bit position visible.
  MTLResourceCPUCacheModeDefaultCache = 0
  MTLResourceStorageModeShared = 0 << 4
class MTLPipelineOption:
  """Integer values passed as the `options` argument when creating a compute pipeline state."""
  MTLPipelineOptionNone = 0
# 13 is the requestType that Metal uses to compile source code into MTLB; there aren't any docs or symbols for it.
REQUEST_TYPE_COMPILE = 13
# Shared libraries used by this backend, loaded once at import time through ctypes.
libobjc = ctypes . CDLL ( " /usr/lib/libobjc.dylib " )
libmetal = ctypes . CDLL ( " /System/Library/Frameworks/Metal.framework/Metal " )
compiler = ctypes . CDLL ( " /System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler " )
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes . CDLL ( " /System/Library/Frameworks/CoreGraphics.framework/CoreGraphics " )
libdispatch = ctypes . CDLL ( " /usr/lib/libSystem.dylib " ) # libdispatch is part of libSystem on mac
# Return types: objc_id / objc_instance keep ctypes from turning returned pointers into plain ints;
# objc_instance additionally releases the object when the Python wrapper is destroyed.
libobjc . objc_getClass . restype = objc_id
libobjc . sel_registerName . restype = objc_id
libmetal . MTLCreateSystemDefaultDevice . restype = objc_instance
compiler . MTLCodeGenServiceCreate . restype = ctypes . c_void_p
libdispatch . dispatch_data_create . restype = objc_instance
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T:  # type: ignore [assignment]
  """Send an Objective-C message to `ptr` through objc_msgSend and return the result as `restype`.

  `selector` is resolved with `sel()`; `args` are forwarded unchanged after the receiver and selector.
  """
  # Item access (libobjc[...]) hands back a fresh function object on every call, so setting restype
  # here does not leak into other call sites.
  send = libobjc[" objc_msgSend "]
  send.restype = restype
  return send(ptr, sel(selector), *args)
def to_ns_str(s: str):
  """Build an NSString from the Python string `s` (UTF-8 encoded) and return it as an objc_instance."""
  ns_string_cls = libobjc.objc_getClass(b" NSString ")
  return msg(ns_string_cls, " stringWithUTF8String: ", s.encode(), restype=objc_instance)
def from_ns_str(s):
  """Return the Python str for the NSString handle `s` by decoding its UTF-8 C string."""
  utf8 = msg(s, " UTF8String ", restype=ctypes.c_char_p)
  return bytes(utf8).decode()
def to_struct(*t: int, _type: type = ctypes.c_ulong):
  """Pack the positional ints `t` into a C struct with one `_type` field per value, in order."""
  field_specs = tuple((f" field { i } ", _type) for i in range(len(t)))
  struct_cls = init_c_struct_t(field_specs)
  return struct_cls(*t)
def wait_check(cbuf: Any):
  """Wait for command buffer `cbuf` to complete, then raise (via error_check) if it reports an error."""
  msg(cbuf, " waitUntilCompleted ")
  cbuf_error = msg(cbuf, " error ", restype=objc_instance)
  error_check(cbuf_error)
def cmdbuf_label(cbuf: objc_id) -> str:
  """Return the label of command buffer `cbuf` as a Python string."""
  ns_label = msg(cbuf, " label ", restype=objc_id)
  return from_ns_str(ns_label)
def cmdbuf_st_time(cbuf: objc_id) -> float:
  """Return the GPU start time that command buffer `cbuf` reports, read as a C double."""
  started = msg(cbuf, " GPUStartTime ", restype=ctypes.c_double)
  return cast(float, started)
def cmdbuf_en_time(cbuf: objc_id) -> float:
  """Return the GPU end time that command buffer `cbuf` reports, read as a C double."""
  ended = msg(cbuf, " GPUEndTime ", restype=ctypes.c_double)
  return cast(float, ended)
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
  """Raise `error_constructor` with the error's localized description if `error` is a non-null handle.

  Returns None when `error.value` is None (no error was reported).
  """
  if error.value is not None:
    description = msg(error, " localizedDescription ", restype=objc_instance)
    raise error_constructor(from_ns_str(description))
  return None
def metal_src_to_library(device: MetalDevice, src: str) -> objc_instance:
  """Compile Metal source `src` into a library on `device.sysdevice`.

  Fast math is enabled according to the METAL_FAST_MATH environment setting.
  Raises CompileError if the framework reports a compilation error.
  """
  compile_options = msg(libobjc.objc_getClass(b" MTLCompileOptions "), " new ", restype=objc_instance)
  msg(compile_options, " setFastMathEnabled: ", getenv(" METAL_FAST_MATH "))
  # Out-parameter: filled in by the framework when compilation fails.
  compileError = objc_instance()
  library = msg(device.sysdevice, " newLibraryWithSource:options:error: ", to_ns_str(src), compile_options,
                ctypes.byref(compileError), restype=objc_instance)
  error_check(compileError, CompileError)
  return library
class MetalCompiler(Compiler):
  """Compiles Metal source to a binary library by calling the MTLCompiler code generation service directly."""
  def __init__(self):
    # Handle to the code generation service; reused for every compile() call on this instance.
    self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b" tinygrad "))
    super().__init__(" compile_metal_direct ")

  def __reduce__(self):
    # force pickle to create new instance for each multiprocessing fork
    return (MetalCompiler, ())

  def compile(self, src: str) -> bytes:
    """Compile `src` and return the library bytes.

    Raises CompileError when the service reports an error or never invokes the callback.
    """
    ret: Union[Exception, bytes] = CompileError(" MTLCodeGenServiceBuildRequest returned without calling the callback ")

    @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
    def callback(blockptr, error, dataPtr, dataLen, errorMessage):
      nonlocal ret
      if error == 0:
        reply = bytes(to_mv(dataPtr, dataLen))
        # offset from beginning to data = header size + warning size
        ret = reply[sum(struct.unpack(' <LL ', reply[8:16])):]
      else:
        ret = CompileError(errorMessage.decode())

    # llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
    # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
    params = f' -fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path= " { cache_dir } " '
    # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated.
    # The padding must be computed from the encoded byte length: for non-ASCII source, len(src) (characters)
    # is smaller than len(src.encode()) (bytes), which would leave the blob mis-sized.
    src_bytes = src.encode()
    src_padded = src_bytes + b' \x00 ' * (round_up(len(src_bytes) + 1, 4) - len(src_bytes))
    params_padded = params.encode() + b' \x00 '
    request = struct.pack(' <QQ ', len(src_padded), len(params_padded)) + src_padded + params_padded
    # The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
    # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
    # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
    # argument and pretend it's a normal callback
    compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
    if isinstance(ret, Exception): raise ret
    assert ret[:4] == b" MTLB " and ret[-4:] == b" ENDT ", f" Invalid Metal library. { ret !r} "
    return ret

  def disassemble(self, lib: bytes):
    """Write `lib` to a temporary file and run the external applegpu disassembler script on it."""
    with tempfile.NamedTemporaryFile(delete=True) as shader:
      shader.write(lib)
      shader.flush()
      ret = os.system(f" cd { pathlib.Path(__file__).parents[2] } /extra/disassemblers/applegpu && python3 compiler_explorer.py { shader.name } ")
      if ret: print(" Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu ")
class MetalProgram:
  """A compiled Metal compute function plus its pipeline state, callable to launch the kernel on `dev`."""
  def __init__(self, dev: MetalDevice, name: str, lib: bytes):
    """Load `lib` (binary library or Metal source) and build the compute pipeline for function `name`.

    Raises RuntimeError if loading, compiling or pipeline creation reports an error.
    """
    self.dev, self.name, self.lib = dev, name, lib
    if lib[:4] == b" MTLB ":
      # binary metal library
      data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
      error_lib = objc_instance()
      self.library = msg(self.dev.sysdevice, " newLibraryWithData:error: ", data, ctypes.byref(error_lib), restype=objc_instance)
      error_check(error_lib)
    else:
      # metal source. rely on OS caching
      try: self.library = metal_src_to_library(self.dev, lib.decode())
      except CompileError as e: raise RuntimeError from e
    self.fxn = msg(self.library, " newFunctionWithName: ", to_ns_str(name), restype=objc_instance)
    descriptor = msg(libobjc.objc_getClass(b" MTLComputePipelineDescriptor "), " new ", restype=objc_instance)
    msg(descriptor, " setComputeFunction: ", self.fxn)
    msg(descriptor, " setSupportIndirectCommandBuffers: ", True)
    error_pipeline_creation = objc_instance()
    self.pipeline_state = msg(self.dev.sysdevice, " newComputePipelineStateWithDescriptor:options:reflection:error: ",
                              descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation),
                              restype=objc_instance)
    error_check(error_pipeline_creation)

  def __call__(self, *bufs, global_size: tuple[int, int, int] = (1, 1, 1), local_size: tuple[int, int, int] = (1, 1, 1),
               vals: tuple[int, ...] = (), wait=False):
    """Encode and commit one dispatch of this kernel.

    `bufs` are bound at indices 0..len(bufs)-1 and `vals` follow as 4-byte ints. The command buffer is
    appended to the device's in-flight list. When `wait` is true, blocks until completion and returns
    GPU end time minus GPU start time; otherwise returns None.
    Raises RuntimeError if `local_size` exceeds the pipeline's maximum threads per threadgroup.
    """
    max_total_threads = msg(self.pipeline_state, " maxTotalThreadsPerThreadgroup ", restype=ctypes.c_ulong)
    if prod(local_size) > cast(int, max_total_threads):
      exec_width = msg(self.pipeline_state, " threadExecutionWidth ", restype=ctypes.c_ulong)
      memory_length = msg(self.pipeline_state, " staticThreadgroupMemoryLength ", restype=ctypes.c_ulong)
      raise RuntimeError(f" local size { local_size } bigger than { max_total_threads } with exec width { exec_width } memory length { memory_length } ")
    command_buffer = msg(self.dev.mtl_queue, " commandBuffer ", restype=objc_instance)
    encoder = msg(command_buffer, " computeCommandEncoder ", restype=objc_instance)
    msg(encoder, " setComputePipelineState: ", self.pipeline_state)
    # Offsets, lengths and indices are wrapped in c_ulong: a bare Python int is passed by ctypes as a C int
    # and masked to fit, which would silently truncate offsets >= 2**31.
    for i, a in enumerate(bufs):
      msg(encoder, " setBuffer:offset:atIndex: ", a.buf, ctypes.c_ulong(a.offset), ctypes.c_ulong(i))
    for i, a in enumerate(vals, start=len(bufs)):
      msg(encoder, " setBytes:length:atIndex: ", bytes(ctypes.c_int(a)), ctypes.c_ulong(4), ctypes.c_ulong(i))
    msg(encoder, " dispatchThreadgroups:threadsPerThreadgroup: ", to_struct(*global_size), to_struct(*local_size))
    msg(encoder, " endEncoding ")
    msg(command_buffer, " setLabel: ", to_ns_str(self.name))
    msg(command_buffer, " commit ")
    self.dev.mtl_buffers_in_flight.append(command_buffer)
    if wait:
      wait_check(command_buffer)
      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
class MetalBuffer:
  """Plain record describing a buffer handle together with its size and an offset into it."""
  def __init__(self, buf: Any, size: int, offset=0):
    self.buf = buf
    self.size = size
    self.offset = offset
class MetalAllocator(LRUAllocator):
  """Allocator that creates, frees, copies and views Metal buffers for one MetalDevice."""
  def __init__(self, dev: MetalDevice):
    self.dev: MetalDevice = dev
    super().__init__()

  def _alloc(self, size: int, options) -> MetalBuffer:
    """Allocate `size` bytes on the device; raises MemoryError if the device returns a null buffer."""
    # Buffer is explicitly released in _free() rather than garbage collected via reference count
    handle = msg(self.dev.sysdevice, " newBufferWithLength:options: ", ctypes.c_ulong(size),
                 MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
    if handle.value is None:
      raise MemoryError(f" Metal OOM while allocating { size =} ")
    return MetalBuffer(handle, size)

  def _free(self, opaque: MetalBuffer, options):
    """Release the underlying buffer object of `opaque`."""
    msg(opaque.buf, " release ")

  def _transfer(self, dest: MetalBuffer, src: MetalBuffer, sz: int, src_dev: MetalDevice, dest_dev: MetalDevice):
    """Copy `sz` bytes from `src` to `dest` with a blit encoded on the source device's queue.

    When the devices differ, the source command buffer signals the source timeline event and a command
    buffer on the destination queue waits for that value.
    """
    dest_dev.synchronize()
    copy_cbuf = msg(src_dev.mtl_queue, " commandBuffer ", restype=objc_instance)
    blit = msg(copy_cbuf, " blitCommandEncoder ", restype=objc_instance)
    msg(blit, " copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size: ", src.buf, ctypes.c_ulong(src.offset),
        dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
    msg(blit, " endEncoding ")
    if src_dev != dest_dev:
      msg(copy_cbuf, " encodeSignalEvent:value: ", src_dev.timeline_signal, src_dev.timeline_value)
      wait_cbuf = msg(dest_dev.mtl_queue, " commandBuffer ", restype=objc_instance)
      msg(wait_cbuf, " encodeWaitForEvent:value: ", src_dev.timeline_signal, src_dev.timeline_value)
      msg(wait_cbuf, " commit ")
      dest_dev.mtl_buffers_in_flight.append(wait_cbuf)
      src_dev.timeline_value += 1
    msg(copy_cbuf, " commit ")
    src_dev.mtl_buffers_in_flight.append(copy_cbuf)

  def _cp_mv(self, dst, src, prof_desc):
    """Copy `src` into `dst` inside a cpu_profile region labelled `prof_desc`."""
    with cpu_profile(prof_desc, self.dev.device, is_copy=True):
      dst[:] = src

  def _as_buffer(self, src: MetalBuffer) -> memoryview:
    """Synchronize the device and return a memoryview of `src`'s contents starting at its offset."""
    self.dev.synchronize()
    base_addr = cast(int, msg(src.buf, " contents ", restype=objc_id).value)
    whole = to_mv(base_addr, src.size + src.offset)
    return whole[src.offset:]

  def _copyin(self, dest: MetalBuffer, src: memoryview):
    """Copy host memory `src` into the device buffer `dest`."""
    self._cp_mv(self._as_buffer(dest), src, " CPU -> METAL ")

  def _copyout(self, dest: memoryview, src: MetalBuffer):
    """Copy the device buffer `src` into host memory `dest`."""
    self._cp_mv(dest, self._as_buffer(src), " METAL -> CPU ")

  def _offset(self, buf: MetalBuffer, size: int, offset: int):
    """Return a MetalBuffer sharing `buf`'s handle, with the given `size` and `offset`."""
    return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
  """Compiled device backed by the system default Metal device and a single command queue."""
  def __init__(self, device: str):
    """Create the system device, command queue and timeline event, then register allocator, renderer and compiler.

    Raises RuntimeError if the command queue could not be created.
    """
    self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
    self.mtl_queue = msg(self.sysdevice, " newCommandQueueWithMaxCommandBufferCount: ", 1024, restype=objc_instance)
    # msg() always returns an objc_instance wrapper, so a failed (nil) allocation shows up as a None
    # pointer *value*, never as the wrapper itself being None.
    if self.mtl_queue.value is None: raise RuntimeError(" Cannot allocate a new command queue ")
    # Command buffers committed but not yet waited on; drained by synchronize().
    self.mtl_buffers_in_flight: list[Any] = []
    self.timeline_signal = msg(self.sysdevice, " newSharedEvent ", restype=objc_instance)
    self.timeline_value = 0

    Compiled.profile_events += [ProfileDeviceEvent(device)]

    from tinygrad.runtime.graph.metal import MetalGraph
    super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv(" METAL_DIRECT ", 1) else Compiler(),
                     functools.partial(MetalProgram, self), MetalGraph)

  def synchronize(self):
    """Wait for every in-flight command buffer (raising on error), record profile events if enabled, then clear the list."""
    for cbuf in self.mtl_buffers_in_flight:
      wait_check(cbuf)
      # Timestamps are only needed for profiling, so skip the extra message sends and Decimal math otherwise.
      if PROFILE:
        st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
        Compiled.profile_events += [ProfileRangeEvent(self.device, cmdbuf_label(cbuf), st, en, is_copy=False)]
    self.mtl_buffers_in_flight.clear()