#!/usr/bin/env python3
import multiprocessing , pickle , difflib , os , threading , json , time , sys , webbrowser , socket , argparse , socketserver , functools , decimal , codecs
from http . server import BaseHTTPRequestHandler
from urllib . parse import parse_qs , urlparse
from typing import Any , TypedDict , Generator
from tinygrad . helpers import colored , getenv , tqdm , unwrap , word_wrap , TRACEMETA
from tinygrad . uop . ops import TrackedGraphRewrite , UOp , Ops , lines , GroupOp , srender , sint
from tinygrad . renderer import ProgramSpec
from tinygrad . device import ProfileEvent , ProfileDeviceEvent , ProfileRangeEvent , ProfileGraphEvent , ProfileGraphEntry , ProfilePointEvent
from tinygrad . dtype import dtypes
# Fill color of each Ops node in the graph view; ops without an entry are drawn "#ffffff" (see uops_colors.get in uop_to_json).
# NOTE: later keys win in a dict literal, so explicit entries placed after the ** spreads override the group colors for any op
# that is also a member of GroupOp.Movement / GroupOp.ALU (presumably Ops.THREEFRY is such a member -- confirm).
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
               Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
               Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
               Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
# VIZ API
# ** Metadata for a track_rewrites scope
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]:
  """Summarize every track_rewrites scope for the frontend.

  `keys[i]` identifies a scope and `contexts[i]` holds the graph_rewrite calls recorded in it. Each summary lists those
  steps (name, loc, depth, number of matches and the stripped source line at loc). A ProgramSpec key additionally
  contributes its rendered source, function name and a ref (the id of its ast); any other key is named by its str().
  """
  def _step_summary(s:TrackedGraphRewrite) -> dict:
    path, lineno = s.loc[0], s.loc[1]
    return {"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":lines(path)[lineno-1].strip()}

  summaries: list[dict] = []
  for key, rewrites in zip(keys, contexts):
    steps = [_step_summary(r) for r in rewrites]
    if not isinstance(key, ProgramSpec):
      summaries.append({"name":str(key), "steps":steps})
      continue
    summaries.append({"name":key.name, "kernel_code":key.src, "ref":id(key.ast), "function_name":key.function_name, "steps":steps})
  return summaries
# ** Complete rewrite details for a graph_rewrite call
class GraphRewriteDetails ( TypedDict ) :
graph : dict # JSON serialized UOp for this rewrite step
uop : str # strigified UOp for this rewrite step
diff : list [ str ] | None # diff of the single UOp that changed
changed_nodes : list [ int ] | None # the changed UOp id + all its parents ids
upat : tuple [ tuple [ str , int ] , str ] | None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]):
  """Format a shape/stride tuple as "(d0,d1,...)", rendering each (possibly symbolic) entry with srender."""
  rendered = [srender(dim) for dim in s]
  return f"({','.join(rendered)})"
def mask_to_str(s:tuple[tuple[sint, sint], ...]):
  """Format a mask (one pair of bounds per dimension) as "((a,b),(c,d),...)", each pair rendered by shape_to_str."""
  return f"({','.join(map(shape_to_str, s))})"
def _unescape_arg(s:str) -> str:
  """Turn backslash escape sequences in `s` (e.g. the backslash-n escapes str() puts in a tuple of strings) into real characters.

  Encoding to latin-1 with backslashreplace first keeps non-ASCII characters intact through the unicode_escape decode
  (decoding the str directly goes through its UTF-8 bytes and mangles them), and malformed escapes such as a trailing
  backslash fall back to the original text instead of raising UnicodeDecodeError and aborting the whole graph.
  """
  try: return s.encode("latin-1", "backslashreplace").decode("unicode_escape")
  except UnicodeDecodeError: return s

def uop_to_json(x:UOp) -> dict[int, dict]:
  """Serialize the graph reachable from `x` into {id(uop): node} for the viz frontend.

  DEVICE, CONST and UNIQUE nodes are not drawn; neither are the sources of a CONST whose first source only feeds CONSTs
  in this graph. Excluded sources are folded into their users' labels instead. Each node is a dict with
  "label" (op name, arg, dtype, folded sources, shape and optionally metadata), "src" (ids of drawn sources),
  "color" (from uops_colors), "ref" (id of the arg's ast for KERNEL, else None) and "tag".
  """
  assert isinstance(x, UOp)
  graph: dict[int, dict] = {}
  excluded: set[UOp] = set()
  for u in (toposort:=x.toposort()):
    # always exclude DEVICE/CONST/UNIQUE
    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE}: excluded.add(u)
    # only exclude CONST VIEW source if it has no other children in the graph (children are weakrefs; dead ones are skipped)
    if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
      excluded.update(u.src)
  for u in toposort:
    if u in excluded: continue
    argst = _unescape_arg(str(u.arg))
    if u.op is Ops.VIEW:
      argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
                          (f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
    label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
    if u.dtype != dtypes.void: label += f"\n{u.dtype}"
    # excluded sources are not drawn as nodes, so show them inline in this node's label
    for idx, s in enumerate(u.src):
      if s in excluded:
        # format by the constant's own dtype (not the user's) so e.g. a bool condition of a float WHERE prints as True
        if s.op is Ops.CONST and dtypes.is_float(s.dtype): label += f"\nCONST{idx} {s.arg:g}"
        else: label += f"\n{s.op.name}{idx} {s.arg}"
    # shape is best-effort: a failure only annotates the label instead of breaking the whole graph
    try:
      if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
        label += f"\n{shape_to_str(u.shape)}"
    except Exception:
      label += "\n<ISSUE GETTING SHAPE>"
    # NOTE: kernel already has metadata in arg
    if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
    graph[id(u)] = {"label":label, "src":[id(s) for s in u.src if s not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
                    "ref":id(u.arg.ast) if u.op is Ops.KERNEL else None, "tag":u.tag}
  return graph
@functools.cache
def _reconstruct(a:int):
  """Rebuild the UOp stored at index `a` of the loaded contexts[2] table; memoized, so each index is rebuilt only once.

  Table entries are (op, dtype, src indices, arg, tag). For KERNEL the arg's `ast` is itself an index and is rebuilt
  recursively before re-creating the arg with its metadata.
  """
  op, dtype, src, arg, tag = contexts[2][a]
  if op is Ops.KERNEL:
    arg = type(arg)(_reconstruct(arg.ast), arg.metadata)
  return UOp(op, dtype, tuple(map(_reconstruct, src)), arg, tag)
def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
  """Replay one recorded graph_rewrite call step by step.

  First yields the input sink with no diff. Then, for each recorded (before, after, upat) match, adds before->after to
  an accumulating substitution map, substitutes it into the current sink and yields the resulting graph, the ids of the
  replacement's subgraph that appear in it, a unified diff of before vs after and the UPat's location/source. A
  RecursionError during substitution is reported as a NOOP UOp carrying the error text. For top-down rewrites
  (not ctx.bottom_up) the substituted sink becomes the base for the next match.
  """
  sink = _reconstruct(ctx.sink)
  yield {"graph":uop_to_json(sink), "uop":str(sink), "changed_nodes":None, "diff":None, "upat":None}
  replaces: dict[UOp, UOp] = {}
  for before_idx, after_idx, upat in tqdm(ctx.matches):
    after = _reconstruct(after_idx)
    before = _reconstruct(before_idx)
    replaces[before] = after
    try:
      new_sink = sink.substitute(replaces)
    except RecursionError as err:
      new_sink = UOp(Ops.NOOP, arg=str(err))
    new_graph = uop_to_json(new_sink)
    yield {"graph":new_graph, "uop":str(new_sink), "changed_nodes":[id(n) for n in after.toposort() if id(n) in new_graph],
           "diff":list(difflib.unified_diff(str(before).splitlines(), str(after).splitlines())), "upat":(upat.location, upat.printable())}
    if not ctx.bottom_up:
      sink = new_sink
# Profiler API
# Timed event types placed on a device track: what flatten_events yields (ProfileGraphEvents are expanded into their ProfileGraphEntry items).
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[decimal.Decimal, decimal.Decimal, DevEvent], None, None]:
  """Yield (start, end, event) for every timed event in `profile`.

  ProfileRangeEvents yield their own st/en, ProfilePointEvents yield st twice, and every entry of a ProfileGraphEvent
  yields the graph's sigs looked up by the entry's st_id/en_id. All other event types are skipped.
  """
  for ev in profile:
    if isinstance(ev, ProfileRangeEvent):
      yield ev.st, ev.en, ev
    if isinstance(ev, ProfilePointEvent):
      yield ev.st, ev.st, ev
    if isinstance(ev, ProfileGraphEvent):
      yield from ((ev.sigs[ent.st_id], ev.sigs[ent.en_id], ent) for ent in ev.ents)
# timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
  """Assign each event with a non-zero duration a stacking depth for the timeline view.

  `events` are (start, end, duration, event) tuples, presumably sorted by start. Each event goes on the lowest level
  whose last event ended at or before its start; when none is free a new level is opened. Returns the shapes to draw
  and the number of levels used ("maxDepth").
  """
  level_ends: list[int] = []
  shapes: list[dict] = []
  for start, end, dur, ev in events:
    if dur == 0: continue
    for depth, last_end in enumerate(level_ends):
      if start >= last_end:
        level_ends[depth] = end
        break
    else:
      depth = len(level_ends)
      level_ends.append(end)
    shapes.append({"name":ev.name, "st":start, "dur":dur, "depth":depth})
  return {"shapes":shapes, "maxDepth":len(level_ends)}
def mem_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
  """Build the memory chart for one device from its "alloc"/"free" ProfilePointEvents; every other event is ignored.

  Each alloc opens a band whose polyline starts at x=current step, y=bytes live before it, and keeps the event's arg.
  On a free the band is closed at the new step, and every still-live band with a larger ref is lowered by the freed
  byte count (presumably refs grow with allocation order, so these are the bands stacked above it -- confirm).
  Bands still live at the end are extended to the final step.

  Returns {"shapes": bands in first-allocation order, "peak": max live bytes, "timestamps": int(e.st) of each step}.
  """
  step = peak = mem = 0
  bands: dict[int, dict] = {}  # every band ever opened, by ref
  live: dict[int, dict] = {}   # bands not freed yet, by ref
  timestamps: list[int] = []
  for _, _, _, e in events:
    if not isinstance(e, ProfilePointEvent): continue
    if e.name == "alloc":
      band = {"x":[step], "y":[mem], "arg":e.arg}
      bands[e.ref] = band
      live[e.ref] = band
      timestamps.append(int(e.st))
      step += 1
      mem += e.arg["nbytes"]
      peak = max(peak, mem)
    elif e.name == "free":
      timestamps.append(int(e.st))
      step += 1
      freed = live.pop(e.ref)
      freed_bytes = freed["arg"]["nbytes"]
      mem -= freed_bytes
      freed["x"].append(step)
      freed["y"].append(freed["y"][-1])
      for ref, band in live.items():
        if ref <= e.ref: continue
        top = band["y"][-1]
        band["x"].extend((step, step))
        band["y"].extend((top, top - freed_bytes))
  for band in live.values():
    band["x"].append(step)
    band["y"].append(band["y"][-1])
  return {"shapes":list(bands.values()), "peak":peak, "timestamps":timestamps}
def get_profile(profile:list[ProfileEvent]):
  """Lay out profile events for the profiler view and return the result as UTF-8 encoded JSON bytes.

  ProfileDeviceEvents supply each device's clock offsets (compute, copy). Every other timed event is shifted by its
  device's offset (point events are not shifted), grouped per device, sorted by start and laid out with
  timeline_layout and mem_layout. "st"/"et" are the earliest start and latest end over all events (null if none).
  """
  # start by getting the time diffs: device -> (compute diff, copy diff); the copy diff falls back to the compute diff
  devs = {e.device:(e.comp_tdiff, e.copy_tdiff if e.copy_tdiff is not None else e.comp_tdiff) for e in profile if isinstance(e, ProfileDeviceEvent)}
  # map events per device
  dev_events: dict[str, list[tuple[int, int, float, DevEvent]]] = {}
  min_ts: int|None = None
  max_ts: int|None = None
  for ts, en, e in flatten_events(profile):
    # is_copy (False when the event has no such field) indexes the (compute, copy) pair; unknown devices get no offset
    time_diff = devs[e.device][getattr(e, "is_copy", False)] if e.device in devs else decimal.Decimal(0)
    is_point = isinstance(e, ProfilePointEvent)
    # ProfilePointEvent records perf_counter, so neither its start nor its end is offset by the GPU time diff
    st = int(ts) if is_point else int(ts+time_diff)
    et = st if en is None or is_point else int(en+time_diff)
    # an event without an end gets zero duration (instead of a TypeError on en-ts)
    dur = 0.0 if en is None else float(en-ts)
    dev_events.setdefault(e.device, []).append((st, et, dur, e))
    if min_ts is None or st < min_ts: min_ts = st
    if max_ts is None or et > max_ts: max_ts = et
  # return layout of per device events
  for events in dev_events.values(): events.sort(key=lambda v:v[0])
  dev_layout = {k:{"timeline":timeline_layout(v), "mem":mem_layout(v)} for k,v in dev_events.items()}
  return json.dumps({"layout":dev_layout, "st":min_ts, "et":max_ts}).encode("utf-8")
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
  """Serves the viz frontend, its static files, rewrite details and profile data.

  GET routes:
    /, /profiler        index.html
    /assets/*, /js/*    static files next to this module (404 if missing)
    /ctxs               JSON list of rewrite-scope metadata (ctxs)
    /ctxs?ctx=K&idx=R   server-sent event stream of get_details for step R of scope K (404 if there is no such step)
    /get_profile        the precomputed profile JSON, when a profile was loaded
  Everything else is a 404.
  """
  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"
    url = urlparse(self.path)
    # the __main__ block opens /profiler when no kernels are loaded; serve the index page there instead of a 404
    # (assumes the frontend handles that path -- confirm)
    if (fn:={"/":"index", "/profiler":"index"}.get(url.path)):
      with open(os.path.join(os.path.dirname(__file__), f"{fn}.html"), "rb") as f: ret = f.read()
    # resolve static files from the parsed path, so a query string does not end up in the file name
    elif url.path.startswith(("/assets/", "/js/")) and '/..' not in url.path:
      try:
        with open(os.path.join(os.path.dirname(__file__), url.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      except (FileNotFoundError, IsADirectoryError): status_code = 404
    elif url.path == "/ctxs":
      if "ctx" in (query:=parse_qs(url.query)):
        # look the step up before sending any headers, so a bad request gets a proper 404 instead of a broken stream
        if (rewrite:=self._get_rewrite(query)) is None: status_code = 404
        else: return self._stream_details(rewrite)
      else: ret, content_type = json.dumps(ctxs).encode(), "application/json"
    elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
    else: status_code = 404
    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)

  def _get_rewrite(self, query:dict[str, list[str]]):
    """Return contexts[1][ctx][idx] for the query's ctx/idx params, or None if they are missing, not ints or out of range."""
    # TypeError covers contexts being None (the server was started without --kernels)
    try: return contexts[1][int(query["ctx"][0])][int(query["idx"][0])]
    except (KeyError, ValueError, IndexError, TypeError): return None

  def _stream_details(self, rewrite):
    """Stream every item of get_details(rewrite) as a server-sent event, followed by a final "data: END" event."""
    try:
      self.send_response(200)
      self.send_header("Content-Type", "text/event-stream")
      self.send_header("Cache-Control", "no-cache")
      self.end_headers()
      for r in get_details(rewrite):
        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
        self.wfile.flush()
      self.wfile.write("data: END\n\n".encode("utf-8"))
      return self.wfile.flush()
    # pass if client closed connection
    except (BrokenPipeError, ConnectionResetError): return
# ** main loop
def reloader():
  """Restart the server whenever this source file changes.

  Polls the file's mtime every 0.1s until the module-level stop_reloader event is set. On a change it re-executes the
  current interpreter with the same argv via os.execv, which replaces the running process.
  """
  start_mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if os.stat(__file__).st_mtime != start_mtime:
      print("reloading server...")
      os.execv(sys.executable, [sys.executable, *sys.argv])
    time.sleep(0.1)
def load_pickle ( path : str ) :
if path is None or not os . path . exists ( path ) : return None
with open ( path , " rb " ) as f : return pickle . load ( f )
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
  """TCPServer with allow_reuse_address (SO_REUSEADDR) set, so the port can be rebound right after a previous server on it exits."""
  allow_reuse_address = True
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
  parser.add_argument('--profile', type=str, help='Path profile', default=None)
  args = parser.parse_args()

  # refuse to start if something already answers on the port
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess"    # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")

  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)

  # NOTE: this context is a tuple of list[keys] and list[values]
  ctxs = get_metadata(*contexts[:2]) if contexts is not None else []

  profile_ret = get_profile(profile) if profile is not None else None

  # NOTE(review): binds every interface ('') although HOST advertises 127.0.0.1 -- confirm remote access is intended
  server = TCPServerWithReuse(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  try:
    print(f"*** started viz on {HOST}:{PORT}")
    print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
    if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
    server.serve_forever()
  except KeyboardInterrupt:
    print("*** viz is shutting down...")
  finally:
    # always stop the non-daemon reloader thread (otherwise any other exception leaves the process hanging on it)
    # and release the listening socket
    stop_reloader.set()
    server.server_close()