import time , itertools
from tinygrad . uop . ops import Variable
from tinygrad . engine . jit import MultiGraphRunner
from tinygrad . engine . realize import CompiledRunner , BufferXfer , ExecItem
from tinygrad . device import Device , Compiled , Buffer
from tinygrad . runtime . ops_remote import RemoteDevice , RemoteConnection , RemoteRequest , GraphComputeItem , Transfer , GraphAlloc , GraphFree , GraphExec
from tinygrad . runtime . ops_remote import BatchTransfer , Event , Wait
from tinygrad . helpers import unwrap , flatten , dedup
from enum import Enum , auto
from dataclasses import replace
from collections import defaultdict
from typing import cast
class StagingType ( Enum ) : NONE = auto ( ) ; GRAPH = auto ( ) ; TRANSFER = auto ( ) # noqa: E702
def rd ( dev : Compiled ) - > RemoteDevice : return cast ( RemoteDevice , dev )
def dev_key ( dev : RemoteDevice ) : return dev . conn if dev . properties . graph_supports_multi else dev
def map_rawbuf ( rawbuf : Buffer ) : return ( cast ( RemoteDevice , Device [ rawbuf . device ] ) . session , rawbuf . _buf )
class RemoteGraph ( MultiGraphRunner ) :
def __init__ ( self , jit_cache : list [ ExecItem ] , rawbufs : list [ Buffer ] , var_vals : dict [ Variable , int ] ) :
super ( ) . __init__ ( jit_cache , rawbufs , var_vals )
devices = dedup ( flatten ( [ [ Device [ unwrap ( buf ) . device ] for buf in ji . bufs ] for ji in jit_cache ] ) )
c2d = { device . conn : device for device in devices }
self . handle_indexes = { map_rawbuf ( rawbufs [ i ] ) : i for i in sorted ( dedup ( self . input_replace . values ( ) ) ) }
self . template : list [ RemoteRequest ] = [ ]
stagings : dict [ RemoteDevice | RemoteConnection , list [ GraphComputeItem | Transfer ] ] = defaultdict ( list )
clobbered_buffers : set [ Buffer ] = set ( )
cur_staging_type : StagingType = StagingType . NONE
def _flush ( new_staging_type : StagingType , force_break : bool = False ) :
nonlocal cur_staging_type
if cur_staging_type == new_staging_type and not force_break : return
# Pre-sync
if cur_staging_type == StagingType . TRANSFER :
for sdev , ddev in itertools . permutations ( c2d . values ( ) , 2 ) :
self . template . append ( Event ( ddev . session , event := next ( ddev . event_num ) , session = sdev . session ) )
self . template . append ( Wait ( event , session = ddev . session ) )
# Flush
for dev in devices :
dk = dev_key ( dev )
staging = stagings [ dk ]
if not staging : continue
match cur_staging_type :
case StagingType . GRAPH :
bufs = tuple ( map_rawbuf ( rawbufs [ i ] ) for i in sorted ( dedup ( self . input_replace . values ( ) ) ) if dev_key ( rd ( Device [ rawbufs [ i ] . device ] ) ) == dk )
dev . q ( GraphAlloc ( graph_num := next ( dev . graph_num ) , tuple ( staging ) , tuple ( bufs ) , var_vals ) )
self . template . append ( GraphExec ( graph_num , bufs , var_vals , wait = False , session = dev . session ) )
case StagingType . TRANSFER :
st = cast ( list [ Transfer ] , staging )
for host in dedup ( t . dsession . host for t in st ) :
sbuffer_nums = [ ( unwrap ( t . session ) , t . buffer_num ) for t in st if t . dsession . host == host ]
dbuffer_nums = [ ( t . dsession , t . dbuffer_num ) for t in st if t . dsession . host == host ]
self . template . append ( BatchTransfer ( sbuffer_nums , dbuffer_nums , session = dev . session ) )
staging . clear ( )
# Post-sync
if cur_staging_type == StagingType . TRANSFER :
for sdev , ddev in itertools . permutations ( c2d . values ( ) , 2 ) :
self . template . append ( Event ( ddev . session , event := next ( ddev . event_num ) , session = sdev . session ) )
self . template . append ( Wait ( event , session = ddev . session ) )
cur_staging_type = new_staging_type
clobbered_buffers . clear ( )
for ji in jit_cache :
match ji . prg :
case CompiledRunner ( ) :
_flush ( StagingType . GRAPH )
gi = GraphComputeItem ( ji . prg . dev . session , ji . prg . _prg . name , ji . prg . _prg . datahash , tuple ( unwrap ( buf ) . _buf for buf in ji . bufs ) ,
tuple ( ji . prg . p . vars ) , ji . fixedvars , tuple ( ji . prg . p . ins ) , tuple ( ji . prg . p . outs ) ,
tuple ( ji . prg . p . global_size ) if ji . prg . p . global_size is not None else None ,
tuple ( ji . prg . p . local_size ) if ji . prg . p . local_size is not None else None )
stagings [ dev_key ( ji . prg . dev ) ] . append ( gi )
case BufferXfer ( ) :
dest , src = ji . bufs [ 0 : 2 ]
dest_dev , src_dev = cast ( RemoteDevice , Device [ unwrap ( dest ) . device ] ) , cast ( RemoteDevice , Device [ unwrap ( src ) . device ] )
assert dest is not None and src is not None , ji
ti = Transfer ( session = src_dev . session , buffer_num = src . _buf , dsession = dest_dev . session , dbuffer_num = dest . _buf )
if dev_key ( dest_dev ) == dev_key ( src_dev ) :
_flush ( StagingType . GRAPH )
stagings [ dev_key ( src_dev ) ] . append ( ti )
elif dest_dev . conn == src_dev . conn :
_flush ( StagingType . NONE )
self . template . append ( ti )
else :
_flush ( StagingType . TRANSFER , force_break = src in clobbered_buffers )
clobbered_buffers . add ( dest )
stagings [ dev_key ( src_dev ) ] . append ( ti )
case _ : raise NotImplementedError ( ji . prg )
_flush ( StagingType . NONE )
def __del__ ( self ) :
for req in self . template :
match req :
case GraphExec ( ) : RemoteConnection ( unwrap ( req . session ) . host ) . q ( GraphFree ( req . graph_num , session = req . session ) )
def __call__ ( self , rawbufs : list [ Buffer ] , var_vals : dict [ Variable , int ] , wait = False ) :
if wait : st = time . perf_counter ( )
rmap = { orig : map_rawbuf ( rawbufs [ replace_idx ] ) for orig , replace_idx in self . handle_indexes . items ( ) }
for req in self . template :
match req :
case GraphExec ( ) :
req = replace ( req , bufs = tuple ( rmap [ buf ] for buf in req . bufs ) , var_vals = var_vals , wait = wait )
case Transfer ( ) :
if ( req . session , req . buffer_num ) in rmap : req = replace ( req , buffer_num = rmap [ ( req . session , req . buffer_num ) ] [ 1 ] )
if ( req . dsession , req . dbuffer_num ) in rmap : req = replace ( req , dbuffer_num = rmap [ ( req . dsession , req . dbuffer_num ) ] [ 1 ] )
case BatchTransfer ( ) :
req = replace ( req , sbuffer_nums = [ rmap . get ( b , b ) for b in req . sbuffer_nums ] , dbuffer_nums = [ rmap . get ( b , b ) for b in req . dbuffer_nums ] )
case Event ( ) | Wait ( ) :
pass # event number can be reused
case _ : raise NotImplementedError ( req )
RemoteConnection ( unwrap ( req . session ) . host ) . q ( req )
if wait :
RemoteConnection ( unwrap ( req . session ) . host ) . batch_submit ( )
return time . perf_counter ( ) - st